Remove unused async database and CRUD modules

- Deleted `database_async.py`, `base_async.py`, and `organization_async.py` modules due to deprecation and unused references across the project.
- Improved overall codebase clarity and minimized redundant functionality by removing unused async database logic, CRUD utilities, and organization-related operations.
This commit is contained in:
Felipe Cardoso
2025-11-01 05:47:43 +01:00
parent ee938ce6a6
commit efcf10f9aa
20 changed files with 972 additions and 2283 deletions

View File

@@ -7,7 +7,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.database_async import get_async_db
from app.core.database import get_db
from app.models.user import User
# OAuth2 configuration
@@ -15,7 +15,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
db: AsyncSession = Depends(get_async_db),
db: AsyncSession = Depends(get_db),
token: str = Depends(oauth2_scheme)
) -> User:
"""
@@ -139,7 +139,7 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
async def get_optional_current_user(
db: AsyncSession = Depends(get_async_db),
db: AsyncSession = Depends(get_db),
token: Optional[str] = Depends(get_optional_token)
) -> Optional[User]:
"""

View File

@@ -14,8 +14,8 @@ from fastapi import Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.database_async import get_async_db
from app.crud.organization_async import organization_async as organization_crud
from app.core.database import get_db
from app.crud.organization import organization as organization_crud
from app.models.user import User
from app.models.user_organization import OrganizationRole
@@ -78,7 +78,7 @@ class OrganizationPermission:
self,
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> User:
"""
Check if user has required role in the organization.
@@ -133,7 +133,7 @@ require_org_member = OrganizationPermission([
async def get_current_org_role(
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Optional[OrganizationRole]:
"""
Get the current user's role in an organization.
@@ -164,7 +164,7 @@ async def get_current_org_role(
async def require_org_membership(
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> User:
"""
Ensure user is a member of the organization (any role).

View File

@@ -15,10 +15,10 @@ from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.permissions import require_superuser
from app.core.database_async import get_async_db
from app.core.database import get_db
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode
from app.crud.organization_async import organization_async as organization_crud
from app.crud.user_async import user_async as user_crud
from app.crud.organization import organization as organization_crud
from app.crud.user import user as user_crud
from app.models.user import User
from app.models.user_organization import OrganizationRole
from app.schemas.common import (
@@ -80,7 +80,7 @@ async def admin_list_users(
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
search: Optional[str] = Query(None, description="Search by email, name"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
List all users with comprehensive filtering and search.
@@ -131,7 +131,7 @@ async def admin_list_users(
async def admin_create_user(
user_in: UserCreate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Create a new user with admin privileges.
@@ -163,7 +163,7 @@ async def admin_create_user(
async def admin_get_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Get detailed information about a specific user."""
user = await user_crud.get(db, id=user_id)
@@ -186,7 +186,7 @@ async def admin_update_user(
user_id: UUID,
user_in: UserUpdate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Update user information with admin privileges."""
try:
@@ -218,7 +218,7 @@ async def admin_update_user(
async def admin_delete_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Soft delete a user (sets deleted_at timestamp)."""
try:
@@ -262,7 +262,7 @@ async def admin_delete_user(
async def admin_activate_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Activate a user account."""
try:
@@ -298,7 +298,7 @@ async def admin_activate_user(
async def admin_deactivate_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Deactivate a user account."""
try:
@@ -342,7 +342,7 @@ async def admin_deactivate_user(
async def admin_bulk_user_action(
bulk_action: BulkUserAction,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Perform bulk actions on multiple users using optimized bulk operations.
@@ -410,7 +410,7 @@ async def admin_list_organizations(
is_active: Optional[bool] = Query(None, description="Filter by active status"),
search: Optional[str] = Query(None, description="Search by name, slug, description"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""List all organizations with filtering and search."""
try:
@@ -467,7 +467,7 @@ async def admin_list_organizations(
async def admin_create_organization(
org_in: OrganizationCreate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Create a new organization."""
try:
@@ -509,7 +509,7 @@ async def admin_create_organization(
async def admin_get_organization(
org_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Get detailed information about a specific organization."""
org = await organization_crud.get(db, id=org_id)
@@ -544,7 +544,7 @@ async def admin_update_organization(
org_id: UUID,
org_in: OrganizationUpdate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Update organization information."""
try:
@@ -588,7 +588,7 @@ async def admin_update_organization(
async def admin_delete_organization(
org_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Delete an organization and all its relationships."""
try:
@@ -626,7 +626,7 @@ async def admin_list_organization_members(
pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(True, description="Filter by active status"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""List all members of an organization."""
try:
@@ -681,7 +681,7 @@ async def admin_add_organization_member(
org_id: UUID,
request: AddMemberRequest,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Add a user to an organization."""
try:
@@ -742,7 +742,7 @@ async def admin_remove_organization_member(
org_id: UUID,
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""Remove a user from an organization."""
try:

View File

@@ -13,14 +13,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
from app.core.auth import get_password_hash
from app.core.database_async import get_async_db
from app.core.database import get_db
from app.core.exceptions import (
AuthenticationError as AuthError,
DatabaseError,
ErrorCode
)
from app.crud.session_async import session_async as session_crud
from app.crud.user_async import user_async as user_crud
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionCreate, LogoutRequest
@@ -54,7 +54,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1
async def register_user(
request: Request,
user_data: UserCreate,
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Register a new user.
@@ -85,7 +85,7 @@ async def register_user(
async def login(
request: Request,
login_data: LoginRequest,
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Login with username and password.
@@ -167,7 +167,7 @@ async def login(
async def login_oauth(
request: Request,
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
OAuth2-compatible login endpoint, used by the OpenAPI UI.
@@ -244,7 +244,7 @@ async def login_oauth(
async def refresh_token(
request: Request,
refresh_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Refresh access token using a refresh token.
@@ -333,7 +333,7 @@ async def refresh_token(
async def request_password_reset(
request: Request,
reset_request: PasswordResetRequest,
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Request a password reset.
@@ -391,7 +391,7 @@ async def request_password_reset(
async def confirm_password_reset(
request: Request,
reset_confirm: PasswordResetConfirm,
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Confirm password reset with token.
@@ -430,7 +430,7 @@ async def confirm_password_reset(
# SECURITY: Invalidate all existing sessions after password reset
# This prevents stolen sessions from being used after password change
from app.crud.session_async import session_async as session_crud
from app.crud.session import session as session_crud
try:
deactivated_count = await session_crud.deactivate_all_user_sessions(
db,
@@ -478,7 +478,7 @@ async def logout(
request: Request,
logout_request: LogoutRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Logout from current device by deactivating the session.
@@ -566,7 +566,7 @@ async def logout(
async def logout_all(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Logout from all devices by deactivating all user sessions.

View File

@@ -13,9 +13,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership
from app.core.database_async import get_async_db
from app.core.database import get_db
from app.core.exceptions import NotFoundError, ErrorCode
from app.crud.organization_async import organization_async as organization_crud
from app.crud.organization import organization as organization_crud
from app.models.user import User
from app.schemas.common import (
PaginationParams,
@@ -43,7 +43,7 @@ router = APIRouter()
async def get_my_organizations(
is_active: bool = Query(True, description="Filter by active membership"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Get all organizations the current user belongs to.
@@ -93,7 +93,7 @@ async def get_my_organizations(
async def get_organization(
organization_id: UUID,
current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Get details of a specific organization.
@@ -140,7 +140,7 @@ async def get_organization_members(
pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"),
current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Get all members of an organization.
@@ -183,7 +183,7 @@ async def update_organization(
organization_id: UUID,
org_in: OrganizationUpdate,
current_user: User = Depends(require_org_admin),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Update organization details.

View File

@@ -14,9 +14,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.auth import decode_token
from app.core.database_async import get_async_db
from app.core.database import get_db
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
from app.crud.session_async import session_async as session_crud
from app.crud.session import session as session_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionResponse, SessionListResponse
@@ -45,7 +45,7 @@ limiter = Limiter(key_func=get_remote_address)
async def list_my_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
List all active sessions for the current user.
@@ -129,7 +129,7 @@ async def revoke_session(
request: Request,
session_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Revoke a specific session by ID.
@@ -204,7 +204,7 @@ async def revoke_session(
async def cleanup_expired_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Cleanup expired sessions for the current user.

View File

@@ -11,13 +11,13 @@ from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user, get_current_superuser
from app.core.database_async import get_async_db
from app.core.database import get_db
from app.core.exceptions import (
NotFoundError,
AuthorizationError,
ErrorCode
)
from app.crud.user_async import user_async as user_crud
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.common import (
PaginationParams,
@@ -58,7 +58,7 @@ async def list_users(
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
List all users with pagination, filtering, and sorting.
@@ -138,7 +138,7 @@ def get_current_user_profile(
async def update_current_user(
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Update current user's profile.
@@ -188,7 +188,7 @@ async def update_current_user(
async def get_user_by_id(
user_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Get user by ID.
@@ -236,7 +236,7 @@ async def update_user(
user_id: UUID,
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Update user by ID.
@@ -304,7 +304,7 @@ async def change_current_user_password(
request: Request,
password_change: PasswordChange,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Change current user's password.
@@ -356,7 +356,7 @@ async def change_current_user_password(
async def delete_user(
user_id: UUID,
current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_async_db)
db: AsyncSession = Depends(get_db)
) -> Any:
"""
Delete user by ID (superuser only).

207
backend/app/core/database.py Normal file → Executable file
View File

@@ -1,113 +1,186 @@
# app/core/database.py
import logging
from contextlib import contextmanager
from typing import Generator
"""
Database configuration using SQLAlchemy 2.0 and asyncpg.
from sqlalchemy import create_engine, text
This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints.
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
create_async_engine,
async_sessionmaker,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.orm import DeclarativeBase
from app.core.config import settings
# Configure logging
logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
# Declarative base for models
Base = declarative_base()
# Create engine with optimized settings for PostgreSQL
def create_production_engine():
return create_engine(
settings.database_url,
# Connection pool settings
pool_size=settings.db_pool_size,
max_overflow=settings.db_max_overflow,
pool_timeout=settings.db_pool_timeout,
pool_recycle=settings.db_pool_recycle,
pool_pre_ping=True,
# Query execution settings
connect_args={
"application_name": "eventspace",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
"options": "-c timezone=UTC",
},
isolation_level="READ COMMITTED",
echo=settings.sql_echo,
echo_pool=settings.sql_echo_pool,
)
# Declarative base for models (SQLAlchemy 2.0 style)
class Base(DeclarativeBase):
"""Base class for all database models."""
pass
# Default production engine and session factory
engine = create_production_engine()
SessionLocal = sessionmaker(
def get_async_database_url(url: str) -> str:
"""
Convert sync database URL to async URL.
postgresql:// -> postgresql+asyncpg://
sqlite:// -> sqlite+aiosqlite://
"""
if url.startswith("postgresql://"):
return url.replace("postgresql://", "postgresql+asyncpg://")
elif url.startswith("sqlite://"):
return url.replace("sqlite://", "sqlite+aiosqlite://")
return url
# Create async engine with optimized settings
def create_async_production_engine() -> AsyncEngine:
"""Create an async database engine with production settings."""
async_url = get_async_database_url(settings.database_url)
# Base engine config
engine_config = {
"pool_size": settings.db_pool_size,
"max_overflow": settings.db_max_overflow,
"pool_timeout": settings.db_pool_timeout,
"pool_recycle": settings.db_pool_recycle,
"pool_pre_ping": True,
"echo": settings.sql_echo,
"echo_pool": settings.sql_echo_pool,
}
# Add PostgreSQL-specific connect_args
if "postgresql" in async_url:
engine_config["connect_args"] = {
"server_settings": {
"application_name": "eventspace",
"timezone": "UTC",
},
# asyncpg-specific settings
"command_timeout": 60,
"timeout": 10,
}
return create_async_engine(async_url, **engine_config)
# Create async engine and session factory
engine = create_async_production_engine()
SessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
autocommit=False,
autoflush=False,
bind=engine,
expire_on_commit=False # Prevent unnecessary queries after commit
expire_on_commit=False, # Prevent unnecessary queries after commit
)
# FastAPI dependency
def get_db() -> Generator[Session, None, None]:
# FastAPI dependency for async database sessions
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""
FastAPI dependency that provides a database session.
FastAPI dependency that provides an async database session.
Automatically closes the session after the request completes.
Usage:
@router.get("/users")
async def get_users(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User))
return result.scalars().all()
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
async with SessionLocal() as session:
try:
yield session
finally:
await session.close()
@contextmanager
def transaction_scope() -> Generator[Session, None, None]:
@asynccontextmanager
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
"""
Provide a transactional scope for database operations.
Provide an async transactional scope for database operations.
Automatically commits on success or rolls back on exception.
Useful for grouping multiple operations in a single transaction.
Usage:
with transaction_scope() as db:
user = user_crud.create(db, obj_in=user_create)
profile = profile_crud.create(db, obj_in=profile_create)
async with async_transaction_scope() as db:
user = await user_crud.create(db, obj_in=user_create)
profile = await profile_crud.create(db, obj_in=profile_create)
# Both operations committed together
"""
db = SessionLocal()
try:
yield db
db.commit()
logger.debug("Transaction committed successfully")
except Exception as e:
db.rollback()
logger.error(f"Transaction failed, rolling back: {str(e)}")
raise
finally:
db.close()
async with SessionLocal() as session:
try:
yield session
await session.commit()
logger.debug("Async transaction committed successfully")
except Exception as e:
await session.rollback()
logger.error(f"Async transaction failed, rolling back: {str(e)}")
raise
finally:
await session.close()
def check_database_health() -> bool:
async def check_async_database_health() -> bool:
"""
Check if database connection is healthy.
Check if async database connection is healthy.
Returns True if connection is successful, False otherwise.
"""
try:
with transaction_scope() as db:
db.execute(text("SELECT 1"))
async with async_transaction_scope() as db:
await db.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Database health check failed: {str(e)}")
return False
logger.error(f"Async database health check failed: {str(e)}")
return False
# Alias for consistency with main.py
check_database_health = check_async_database_health
async def init_async_db() -> None:
"""
Initialize async database tables.
This creates all tables defined in the models.
Should only be used in development or testing.
In production, use Alembic migrations.
"""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Async database tables created")
async def close_async_db() -> None:
"""
Close all async database connections.
Should be called during application shutdown.
"""
await engine.dispose()
logger.info("Async database connections closed")

View File

@@ -1,186 +0,0 @@
# app/core/database_async.py
"""
Async database configuration using SQLAlchemy 2.0 and asyncpg.
This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints.
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
create_async_engine,
async_sessionmaker,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import DeclarativeBase
from app.core.config import settings
# Configure logging
logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
# Declarative base for models (SQLAlchemy 2.0 style)
class Base(DeclarativeBase):
"""Base class for all database models."""
pass
def get_async_database_url(url: str) -> str:
"""
Convert sync database URL to async URL.
postgresql:// -> postgresql+asyncpg://
sqlite:// -> sqlite+aiosqlite://
"""
if url.startswith("postgresql://"):
return url.replace("postgresql://", "postgresql+asyncpg://")
elif url.startswith("sqlite://"):
return url.replace("sqlite://", "sqlite+aiosqlite://")
return url
# Create async engine with optimized settings
def create_async_production_engine() -> AsyncEngine:
"""Create an async database engine with production settings."""
async_url = get_async_database_url(settings.database_url)
# Base engine config
engine_config = {
"pool_size": settings.db_pool_size,
"max_overflow": settings.db_max_overflow,
"pool_timeout": settings.db_pool_timeout,
"pool_recycle": settings.db_pool_recycle,
"pool_pre_ping": True,
"echo": settings.sql_echo,
"echo_pool": settings.sql_echo_pool,
}
# Add PostgreSQL-specific connect_args
if "postgresql" in async_url:
engine_config["connect_args"] = {
"server_settings": {
"application_name": "eventspace",
"timezone": "UTC",
},
# asyncpg-specific settings
"command_timeout": 60,
"timeout": 10,
}
return create_async_engine(async_url, **engine_config)
# Create async engine and session factory
async_engine = create_async_production_engine()
AsyncSessionLocal = async_sessionmaker(
async_engine,
class_=AsyncSession,
autocommit=False,
autoflush=False,
expire_on_commit=False, # Prevent unnecessary queries after commit
)
# FastAPI dependency for async database sessions
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""
FastAPI dependency that provides an async database session.
Automatically closes the session after the request completes.
Usage:
@router.get("/users")
async def get_users(db: AsyncSession = Depends(get_async_db)):
result = await db.execute(select(User))
return result.scalars().all()
"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
@asynccontextmanager
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
"""
Provide an async transactional scope for database operations.
Automatically commits on success or rolls back on exception.
Useful for grouping multiple operations in a single transaction.
Usage:
async with async_transaction_scope() as db:
user = await user_crud.create(db, obj_in=user_create)
profile = await profile_crud.create(db, obj_in=profile_create)
# Both operations committed together
"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
logger.debug("Async transaction committed successfully")
except Exception as e:
await session.rollback()
logger.error(f"Async transaction failed, rolling back: {str(e)}")
raise
finally:
await session.close()
async def check_async_database_health() -> bool:
"""
Check if async database connection is healthy.
Returns True if connection is successful, False otherwise.
"""
try:
async with async_transaction_scope() as db:
await db.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Async database health check failed: {str(e)}")
return False
# Alias for consistency with main.py
check_database_health = check_async_database_health
async def init_async_db() -> None:
"""
Initialize async database tables.
This creates all tables defined in the models.
Should only be used in development or testing.
In production, use Alembic migrations.
"""
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Async database tables created")
async def close_async_db() -> None:
"""
Close all async database connections.
Should be called during application shutdown.
"""
await async_engine.dispose()
logger.info("Async database connections closed")

207
backend/app/crud/base.py Normal file → Executable file
View File

@@ -1,13 +1,19 @@
# app/crud/base_async.py
"""
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import asc, desc
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
from app.core.database import Base
@@ -19,17 +25,40 @@ UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
CRUD object with default async methods to Create, Read, Update, Delete.
Parameters:
model: A SQLAlchemy model class
"""
self.model = model
def get(self, db: Session, id: str) -> Optional[ModelType]:
"""Get a single record by ID with UUID validation."""
async def get(
self,
db: AsyncSession,
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
"""
Get a single record by ID with UUID validation and optional eager loading.
Args:
db: Database session
id: Record UUID
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
for eager loading relationships to prevent N+1 queries
Returns:
Model instance or None if not found
Example:
# Eager load user relationship
from sqlalchemy.orm import joinedload
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
"""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
@@ -41,15 +70,39 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return None
try:
return db.query(self.model).filter(self.model.id == uuid_obj).first()
query = select(self.model).where(self.model.id == uuid_obj)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
def get_multi(
self, db: Session, *, skip: int = 0, limit: int = 100
async def get_multi(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
options: Optional[List[Load]] = None
) -> List[ModelType]:
"""Get multiple records with pagination validation."""
"""
Get multiple records with pagination validation and optional eager loading.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
options: Optional list of SQLAlchemy load options for eager loading
Returns:
List of model instances
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
@@ -59,22 +112,30 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise ValueError("Maximum limit is 1000")
try:
return db.query(self.model).offset(skip).limit(limit).all()
query = select(self.model).offset(skip).limit(limit)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
raise
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
"""Create a new record with error handling."""
try:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
@@ -82,20 +143,20 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
db.rollback()
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
raise
def update(
self,
db: Session,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
async def update(
self,
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
"""Update a record with error handling."""
try:
@@ -104,15 +165,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
db.commit()
db.refresh(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
@@ -120,15 +183,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
db.rollback()
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
raise
def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try:
@@ -141,27 +204,31 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return None
try:
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None
db.delete(obj)
db.commit()
await db.delete(obj)
await db.commit()
return obj
except IntegrityError as e:
db.rollback()
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
def get_multi_with_total(
async def get_multi_with_total(
self,
db: Session,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
@@ -193,43 +260,63 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
try:
# Build base query
query = db.query(self.model)
query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
query = query.filter(self.model.deleted_at.is_(None))
query = query.where(self.model.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.filter(getattr(self.model, field) == value)
query = query.where(getattr(self.model, field) == value)
# Get total count (before pagination)
total = query.count()
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(desc(sort_column))
query = query.order_by(sort_column.desc())
else:
query = query.order_by(asc(sort_column))
query = query.order_by(sort_column.asc())
# Apply pagination
items = query.offset(skip).limit(limit).all()
query = query.offset(skip).limit(limit)
items_result = await db.execute(query)
items = list(items_result.scalars().all())
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise
def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]:
async def count(self, db: AsyncSession) -> int:
"""Get total count of records."""
try:
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
"""Check if a record exists by ID."""
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime, timezone
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
@@ -241,7 +328,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return None
try:
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
@@ -255,15 +345,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
db.add(obj)
db.commit()
db.refresh(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
def restore(self, db: Session, *, id: str) -> Optional[ModelType]:
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
@@ -282,10 +372,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
obj = db.query(self.model).filter(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
).first()
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
@@ -297,10 +390,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
# Clear deleted_at timestamp
obj.deleted_at = None
db.add(obj)
db.commit()
db.refresh(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
raise

View File

@@ -1,399 +0,0 @@
# app/crud/base_async.py
"""
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
import logging
import uuid
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
from app.core.database_async import Base
logger = logging.getLogger(__name__)
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default async methods to Create, Read, Update, Delete.
Parameters:
model: A SQLAlchemy model class
"""
self.model = model
async def get(
self,
db: AsyncSession,
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
"""
Get a single record by ID with UUID validation and optional eager loading.
Args:
db: Database session
id: Record UUID
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
for eager loading relationships to prevent N+1 queries
Returns:
Model instance or None if not found
Example:
# Eager load user relationship
from sqlalchemy.orm import joinedload
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
"""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
return None
try:
query = select(self.model).where(self.model.id == uuid_obj)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
async def get_multi(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
options: Optional[List[Load]] = None
) -> List[ModelType]:
"""
Get multiple records with pagination validation and optional eager loading.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
options: Optional list of SQLAlchemy load options for eager loading
Returns:
List of model instances
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
query = select(self.model).offset(skip).limit(limit)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
"""Create a new record with error handling."""
try:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
"""Update a record with error handling."""
try:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
raise
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None
await db.delete(obj)
await db.commit()
return obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
except Exception as e:
await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
async def get_multi_with_total(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]:
"""
Get multiple records with total count, filtering, and sorting.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns:
Tuple of (items, total_count)
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
# Build base query
query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
query = query.where(self.model.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.where(getattr(self.model, field) == value)
# Get total count (before pagination)
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
items_result = await db.execute(query)
items = list(items_result.scalars().all())
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise
async def count(self, db: AsyncSession) -> int:
"""Get total count of records."""
try:
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
"""Check if a record exists by ID."""
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime, timezone
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
return None
# Clear deleted_at timestamp
obj.deleted_at = None
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise

434
backend/app/crud/organization.py Normal file → Executable file
View File

@@ -1,11 +1,12 @@
# app/crud/organization.py
# app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
import logging
from typing import Optional, List, Dict, Any
from uuid import UUID
from sqlalchemy import func, or_, and_
from sqlalchemy import func, or_, and_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
from app.models.organization import Organization
@@ -13,20 +14,27 @@ from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.schemas.organizations import (
OrganizationCreate,
OrganizationUpdate
OrganizationUpdate,
)
logger = logging.getLogger(__name__)
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
"""CRUD operations for Organization model."""
"""Async CRUD operations for Organization model."""
def get_by_slug(self, db: Session, *, slug: str) -> Optional[Organization]:
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
"""Get organization by slug."""
return db.query(Organization).filter(Organization.slug == slug).first()
try:
result = await db.execute(
select(Organization).where(Organization.slug == slug)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
raise
def create(self, db: Session, *, obj_in: OrganizationCreate) -> Organization:
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
"""Create a new organization with error handling."""
try:
db_obj = Organization(
@@ -37,11 +45,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
settings=obj_in.settings or {}
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
@@ -49,13 +57,13 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
raise
def get_multi_with_filters(
async def get_multi_with_filters(
self,
db: Session,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
@@ -70,47 +78,139 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
Returns:
Tuple of (organizations list, total count)
"""
query = db.query(Organization)
try:
query = select(Organization)
# Apply filters
if is_active is not None:
query = query.filter(Organization.is_active == is_active)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
)
query = query.filter(search_filter)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count before pagination
total = query.count()
# Get total count before pagination
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(Organization, sort_by, Organization.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply sorting
sort_column = getattr(Organization, sort_by, Organization.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
organizations = query.offset(skip).limit(limit).all()
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
organizations = list(result.scalars().all())
return organizations, total
return organizations, total
except Exception as e:
logger.error(f"Error getting organizations with filters: {str(e)}")
raise
def get_member_count(self, db: Session, *, organization_id: UUID) -> int:
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
"""Get the count of active members in an organization."""
return db.query(func.count(UserOrganization.user_id)).filter(
and_(
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
try:
result = await db.execute(
select(func.count(UserOrganization.user_id)).where(
and_(
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
)
)
)
).scalar() or 0
return result.scalar_one() or 0
except Exception as e:
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
raise
def add_user(
async def get_multi_with_member_counts(
self,
db: Session,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> tuple[List[Dict[str, Any]], int]:
"""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
Returns:
Tuple of (list of dicts with org and member_count, total count)
"""
try:
# Build base query with LEFT JOIN and GROUP BY
query = (
select(
Organization,
func.count(
func.distinct(
and_(
UserOrganization.is_active == True,
UserOrganization.user_id
).self_group()
)
).label('member_count')
)
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.id)
)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count
count_query = select(func.count(Organization.id))
if is_active is not None:
count_query = count_query.where(Organization.is_active == is_active)
if search:
count_query = count_query.where(search_filter)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
rows = result.all()
# Convert to list of dicts
orgs_with_counts = [
{
'organization': org,
'member_count': member_count
}
for org, member_count in rows
]
return orgs_with_counts, total
except Exception as e:
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
raise
async def add_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
@@ -120,12 +220,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
"""Add a user to an organization with a specific role."""
try:
# Check if relationship already exists
existing = db.query(UserOrganization).filter(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
).first()
)
existing = result.scalar_one_or_none()
if existing:
# Reactivate if inactive, or raise error if already active
@@ -133,8 +236,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
existing.is_active = True
existing.role = role
existing.custom_permissions = custom_permissions
db.commit()
db.refresh(existing)
await db.commit()
await db.refresh(existing)
return existing
else:
raise ValueError("User is already a member of this organization")
@@ -148,48 +251,51 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
custom_permissions=custom_permissions
)
db.add(user_org)
db.commit()
db.refresh(user_org)
await db.commit()
await db.refresh(user_org)
return user_org
except IntegrityError as e:
db.rollback()
await db.rollback()
logger.error(f"Integrity error adding user to organization: {str(e)}")
raise ValueError("Failed to add user to organization")
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
raise
def remove_user(
async def remove_user(
self,
db: Session,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID
) -> bool:
"""Remove a user from an organization (soft delete)."""
try:
user_org = db.query(UserOrganization).filter(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
).first()
)
user_org = result.scalar_one_or_none()
if not user_org:
return False
user_org.is_active = False
db.commit()
await db.commit()
return True
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
raise
def update_user_role(
async def update_user_role(
self,
db: Session,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
@@ -198,12 +304,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) -> Optional[UserOrganization]:
"""Update a user's role in an organization."""
try:
user_org = db.query(UserOrganization).filter(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
).first()
)
user_org = result.scalar_one_or_none()
if not user_org:
return None
@@ -211,17 +320,17 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
user_org.role = role
if custom_permissions is not None:
user_org.custom_permissions = custom_permissions
db.commit()
db.refresh(user_org)
await db.commit()
await db.refresh(user_org)
return user_org
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
raise
def get_organization_members(
async def get_organization_members(
self,
db: Session,
db: AsyncSession,
*,
organization_id: UUID,
skip: int = 0,
@@ -234,86 +343,175 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
Returns:
Tuple of (members list with user details, total count)
"""
query = db.query(UserOrganization, User).join(
User, UserOrganization.user_id == User.id
).filter(UserOrganization.organization_id == organization_id)
try:
# Build query with join
query = (
select(UserOrganization, User)
.join(User, UserOrganization.user_id == User.id)
.where(UserOrganization.organization_id == organization_id)
)
if is_active is not None:
query = query.filter(UserOrganization.is_active == is_active)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
total = query.count()
# Get total count
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
.where(UserOrganization.is_active == is_active if is_active is not None else True)
.alias()
)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
results = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit).all()
# Apply ordering and pagination
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
results = result.all()
members = []
for user_org, user in results:
members.append({
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at
})
members = []
for user_org, user in results:
members.append({
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at
})
return members, total
return members, total
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}")
raise
def get_user_organizations(
async def get_user_organizations(
self,
db: Session,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Organization]:
"""Get all organizations a user belongs to."""
query = db.query(Organization).join(
UserOrganization, Organization.id == UserOrganization.organization_id
).filter(UserOrganization.user_id == user_id)
try:
query = (
select(Organization)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.filter(UserOrganization.is_active == is_active)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
return query.all()
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}")
raise
def get_user_role_in_org(
async def get_user_organizations_with_details(
self,
db: Session,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Dict[str, Any]]:
"""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
Returns:
List of dicts with organization, role, and member_count
"""
try:
# Subquery to get member counts for each organization
member_count_subq = (
select(
UserOrganization.organization_id,
func.count(UserOrganization.user_id).label('member_count')
)
.where(UserOrganization.is_active == True)
.group_by(UserOrganization.organization_id)
.subquery()
)
# Main query with JOIN to get org, role, and member count
query = (
select(
Organization,
UserOrganization.role,
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
result = await db.execute(query)
rows = result.all()
return [
{
'organization': org,
'role': role,
'member_count': member_count
}
for org, role, member_count in rows
]
except Exception as e:
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
raise
async def get_user_role_in_org(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> Optional[OrganizationRole]:
"""Get a user's role in a specific organization."""
user_org = db.query(UserOrganization).filter(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
)
)
)
).first()
user_org = result.scalar_one_or_none()
return user_org.role if user_org else None
return user_org.role if user_org else None
except Exception as e:
logger.error(f"Error getting user role in org: {str(e)}")
raise
def is_user_org_owner(
async def is_user_org_owner(
self,
db: Session,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool:
"""Check if a user is an owner of an organization."""
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role == OrganizationRole.OWNER
def is_user_org_admin(
async def is_user_org_admin(
self,
db: Session,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool:
"""Check if a user is an owner or admin of an organization."""
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]

View File

@@ -1,519 +0,0 @@
# app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
import logging
from typing import Optional, List, Dict, Any
from uuid import UUID
from sqlalchemy import func, or_, and_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base_async import CRUDBaseAsync
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.schemas.organizations import (
OrganizationCreate,
OrganizationUpdate,
)
logger = logging.getLogger(__name__)
class CRUDOrganizationAsync(CRUDBaseAsync[Organization, OrganizationCreate, OrganizationUpdate]):
"""Async CRUD operations for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
"""Get organization by slug."""
try:
result = await db.execute(
select(Organization).where(Organization.slug == slug)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
"""Create a new organization with error handling."""
try:
db_obj = Organization(
name=obj_in.name,
slug=obj_in.slug,
description=obj_in.description,
is_active=obj_in.is_active,
settings=obj_in.settings or {}
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
raise
async def get_multi_with_filters(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None,
sort_by: str = "created_at",
sort_order: str = "desc"
) -> tuple[List[Organization], int]:
"""
Get multiple organizations with filtering, searching, and sorting.
Returns:
Tuple of (organizations list, total count)
"""
try:
query = select(Organization)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(Organization, sort_by, Organization.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
organizations = list(result.scalars().all())
return organizations, total
except Exception as e:
logger.error(f"Error getting organizations with filters: {str(e)}")
raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
"""Get the count of active members in an organization."""
try:
result = await db.execute(
select(func.count(UserOrganization.user_id)).where(
and_(
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
)
)
)
return result.scalar_one() or 0
except Exception as e:
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
raise
async def get_multi_with_member_counts(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> tuple[List[Dict[str, Any]], int]:
"""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
Returns:
Tuple of (list of dicts with org and member_count, total count)
"""
try:
# Build base query with LEFT JOIN and GROUP BY
query = (
select(
Organization,
func.count(
func.distinct(
and_(
UserOrganization.is_active == True,
UserOrganization.user_id
).self_group()
)
).label('member_count')
)
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.id)
)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count
count_query = select(func.count(Organization.id))
if is_active is not None:
count_query = count_query.where(Organization.is_active == is_active)
if search:
count_query = count_query.where(search_filter)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
rows = result.all()
# Convert to list of dicts
orgs_with_counts = [
{
'organization': org,
'member_count': member_count
}
for org, member_count in rows
]
return orgs_with_counts, total
except Exception as e:
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
raise
async def add_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER,
custom_permissions: Optional[str] = None
) -> UserOrganization:
"""Add a user to an organization with a specific role."""
try:
# Check if relationship already exists
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
)
existing = result.scalar_one_or_none()
if existing:
# Reactivate if inactive, or raise error if already active
if not existing.is_active:
existing.is_active = True
existing.role = role
existing.custom_permissions = custom_permissions
await db.commit()
await db.refresh(existing)
return existing
else:
raise ValueError("User is already a member of this organization")
# Create new relationship
user_org = UserOrganization(
user_id=user_id,
organization_id=organization_id,
role=role,
is_active=True,
custom_permissions=custom_permissions
)
db.add(user_org)
await db.commit()
await db.refresh(user_org)
return user_org
except IntegrityError as e:
await db.rollback()
logger.error(f"Integrity error adding user to organization: {str(e)}")
raise ValueError("Failed to add user to organization")
except Exception as e:
await db.rollback()
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
raise
async def remove_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID
) -> bool:
"""Remove a user from an organization (soft delete)."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
)
user_org = result.scalar_one_or_none()
if not user_org:
return False
user_org.is_active = False
await db.commit()
return True
except Exception as e:
await db.rollback()
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
raise
async def update_user_role(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole,
custom_permissions: Optional[str] = None
) -> Optional[UserOrganization]:
"""Update a user's role in an organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
)
)
)
user_org = result.scalar_one_or_none()
if not user_org:
return None
user_org.role = role
if custom_permissions is not None:
user_org.custom_permissions = custom_permissions
await db.commit()
await db.refresh(user_org)
return user_org
except Exception as e:
await db.rollback()
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
raise
async def get_organization_members(
self,
db: AsyncSession,
*,
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool = True
) -> tuple[List[Dict[str, Any]], int]:
"""
Get members of an organization with user details.
Returns:
Tuple of (members list with user details, total count)
"""
try:
# Build query with join
query = (
select(UserOrganization, User)
.join(User, UserOrganization.user_id == User.id)
.where(UserOrganization.organization_id == organization_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
# Get total count
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
.where(UserOrganization.is_active == is_active if is_active is not None else True)
.alias()
)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply ordering and pagination
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
results = result.all()
members = []
for user_org, user in results:
members.append({
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at
})
return members, total
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}")
raise
async def get_user_organizations(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Organization]:
"""Get all organizations a user belongs to."""
try:
query = (
select(Organization)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}")
raise
async def get_user_organizations_with_details(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Dict[str, Any]]:
"""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
Returns:
List of dicts with organization, role, and member_count
"""
try:
# Subquery to get member counts for each organization
member_count_subq = (
select(
UserOrganization.organization_id,
func.count(UserOrganization.user_id).label('member_count')
)
.where(UserOrganization.is_active == True)
.group_by(UserOrganization.organization_id)
.subquery()
)
# Main query with JOIN to get org, role, and member count
query = (
select(
Organization,
UserOrganization.role,
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
result = await db.execute(query)
rows = result.all()
return [
{
'organization': org,
'role': role,
'member_count': member_count
}
for org, role, member_count in rows
]
except Exception as e:
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
raise
async def get_user_role_in_org(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> Optional[OrganizationRole]:
"""Get a user's role in a specific organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
)
)
)
user_org = result.scalar_one_or_none()
return user_org.role if user_org else None
except Exception as e:
logger.error(f"Error getting user role in org: {str(e)}")
raise
async def is_user_org_owner(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool:
"""Check if a user is an owner of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role == OrganizationRole.OWNER
async def is_user_org_admin(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool:
"""Check if a user is an owner or admin of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
# Create a singleton instance for use across the application
organization_async = CRUDOrganizationAsync(Organization)

220
backend/app/crud/session.py Normal file → Executable file
View File

@@ -1,13 +1,14 @@
"""
CRUD operations for user sessions.
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
import logging
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy.orm import Session
from sqlalchemy import and_, select, update, delete, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.user_session import UserSession
@@ -17,9 +18,9 @@ logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
"""CRUD operations for user sessions."""
"""Async CRUD operations for user sessions."""
def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get session by refresh token JTI.
@@ -31,14 +32,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
UserSession if found, None otherwise
"""
try:
return db.query(UserSession).filter(
UserSession.refresh_token_jti == jti
).first()
result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
raise
def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get active session by refresh token JTI.
@@ -50,30 +52,35 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Active UserSession if found, None otherwise
"""
try:
return db.query(UserSession).filter(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
result = await db.execute(
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
)
)
).first()
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
raise
def get_user_sessions(
async def get_user_sessions(
self,
db: Session,
db: AsyncSession,
*,
user_id: str,
active_only: bool = True
active_only: bool = True,
with_user: bool = False
) -> List[UserSession]:
"""
Get all sessions for a user.
Get all sessions for a user with optional eager loading.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
List of UserSession objects
@@ -82,19 +89,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = db.query(UserSession).filter(UserSession.user_id == user_uuid)
query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user:
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.filter(UserSession.is_active == True)
query = query.where(UserSession.is_active == True)
return query.order_by(UserSession.last_used_at.desc()).all()
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
raise
def create_session(
async def create_session(
self,
db: Session,
db: AsyncSession,
*,
obj_in: SessionCreate
) -> UserSession:
@@ -126,8 +139,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
location_country=obj_in.location_country,
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
@@ -136,11 +149,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return db_obj
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]:
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
"""
Deactivate a session (logout from device).
@@ -152,15 +165,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Deactivated UserSession if found, None otherwise
"""
try:
session = self.get(db, id=session_id)
session = await self.get(db, id=session_id)
if not session:
logger.warning(f"Session {session_id} not found for deactivation")
return None
session.is_active = False
db.add(session)
db.commit()
db.refresh(session)
await db.commit()
await db.refresh(session)
logger.info(
f"Session {session_id} deactivated for user {session.user_id} "
@@ -169,13 +182,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return session
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
raise
def deactivate_all_user_sessions(
async def deactivate_all_user_sessions(
self,
db: Session,
db: AsyncSession,
*,
user_id: str
) -> int:
@@ -193,26 +206,33 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
count = db.query(UserSession).filter(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
stmt = (
update(UserSession)
.where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
).update({"is_active": False})
.values(is_active=False)
)
db.commit()
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
raise
def update_last_used(
async def update_last_used(
self,
db: Session,
db: AsyncSession,
*,
session: UserSession
) -> UserSession:
@@ -229,17 +249,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
try:
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
db.commit()
db.refresh(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
raise
def update_refresh_token(
async def update_refresh_token(
self,
db: Session,
db: AsyncSession,
*,
session: UserSession,
new_jti: str,
@@ -264,22 +284,24 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
db.commit()
db.refresh(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
raise
def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int:
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions.
Clean up expired sessions using optimized bulk DELETE.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Uses single DELETE query instead of N individual deletes for efficiency.
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
@@ -289,31 +311,87 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
now = datetime.now(timezone.utc)
# Delete sessions that are:
# 1. Expired (expires_at < now) AND inactive
# AND
# 2. Older than keep_days
count = db.query(UserSession).filter(
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False,
UserSession.expires_at < datetime.now(timezone.utc),
UserSession.expires_at < now,
UserSession.created_at < cutoff_date
)
).delete()
)
db.commit()
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired sessions")
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
return count
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
raise
def get_user_session_count(self, db: Session, *, user_id: str) -> int:
async def cleanup_expired_for_user(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
"""
Clean up expired and inactive sessions for a specific user.
Uses single bulk DELETE query for efficiency instead of N individual deletes.
Args:
db: Database session
user_id: User ID to cleanup sessions for
Returns:
Number of sessions deleted
"""
try:
# Validate UUID
try:
uuid_obj = uuid.UUID(user_id)
except (ValueError, AttributeError):
logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}")
now = datetime.now(timezone.utc)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
UserSession.is_active == False,
UserSession.expires_at < now
)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
)
return count
except Exception as e:
await db.rollback()
logger.error(
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
)
raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
"""
Get count of active sessions for a user.
@@ -325,12 +403,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Number of active sessions
"""
try:
return db.query(UserSession).filter(
and_(
UserSession.user_id == user_id,
UserSession.is_active == True
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute(
select(func.count(UserSession.id)).where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
).count()
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
raise

View File

@@ -1,424 +0,0 @@
"""
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
import logging
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy import and_, select, update, delete, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base_async import CRUDBaseAsync
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__)
class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
"""
try:
result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
raise
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get active session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
try:
result = await db.execute(
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
raise
async def get_user_sessions(
self,
db: AsyncSession,
*,
user_id: str,
active_only: bool = True,
with_user: bool = False
) -> List[UserSession]:
"""
Get all sessions for a user with optional eager loading.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
List of UserSession objects
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user:
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
raise
async def create_session(
self,
db: AsyncSession,
*,
obj_in: SessionCreate
) -> UserSession:
"""
Create a new user session.
Args:
db: Database session
obj_in: SessionCreate schema with session data
Returns:
Created UserSession
Raises:
ValueError: If session creation fails
"""
try:
db_obj = UserSession(
user_id=obj_in.user_id,
refresh_token_jti=obj_in.refresh_token_jti,
device_name=obj_in.device_name,
device_id=obj_in.device_id,
ip_address=obj_in.ip_address,
user_agent=obj_in.user_agent,
last_used_at=obj_in.last_used_at,
expires_at=obj_in.expires_at,
is_active=True,
location_city=obj_in.location_city,
location_country=obj_in.location_country,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
f"(IP: {obj_in.ip_address})"
)
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
"""
Deactivate a session (logout from device).
Args:
db: Database session
session_id: Session UUID
Returns:
Deactivated UserSession if found, None otherwise
"""
try:
session = await self.get(db, id=session_id)
if not session:
logger.warning(f"Session {session_id} not found for deactivation")
return None
session.is_active = False
db.add(session)
await db.commit()
await db.refresh(session)
logger.info(
f"Session {session_id} deactivated for user {session.user_id} "
f"({session.device_name})"
)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
raise
async def deactivate_all_user_sessions(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
"""
Deactivate all active sessions for a user (logout from all devices).
Args:
db: Database session
user_id: User ID
Returns:
Number of sessions deactivated
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
stmt = (
update(UserSession)
.where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
.values(is_active=False)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
raise
async def update_last_used(
self,
db: AsyncSession,
*,
session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
Args:
db: Database session
session: UserSession object
Returns:
Updated UserSession
"""
try:
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
raise
async def update_refresh_token(
self,
db: AsyncSession,
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
Called during token refresh.
Args:
db: Database session
session: UserSession object
new_jti: New refresh token JTI
new_expires_at: New expiration datetime
Returns:
Updated UserSession
"""
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions using optimized bulk DELETE.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Uses single DELETE query instead of N individual deletes for efficiency.
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
Returns:
Number of sessions deleted
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
now = datetime.now(timezone.utc)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False,
UserSession.expires_at < now,
UserSession.created_at < cutoff_date
)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
raise
async def cleanup_expired_for_user(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
"""
Clean up expired and inactive sessions for a specific user.
Uses single bulk DELETE query for efficiency instead of N individual deletes.
Args:
db: Database session
user_id: User ID to cleanup sessions for
Returns:
Number of sessions deleted
"""
try:
# Validate UUID
try:
uuid_obj = uuid.UUID(user_id)
except (ValueError, AttributeError):
logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}")
now = datetime.now(timezone.utc)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
UserSession.is_active == False,
UserSession.expires_at < now
)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
)
return count
except Exception as e:
await db.rollback()
logger.error(
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
)
raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
"""
Get count of active sessions for a user.
Args:
db: Database session
user_id: User ID
Returns:
Number of active sessions
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute(
select(func.count(UserSession.id)).where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
raise
# Create singleton instance
session_async = CRUDSessionAsync(UserSession)

183
backend/app/crud/user.py Normal file → Executable file
View File

@@ -1,12 +1,15 @@
# app/crud/user.py
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import datetime, timezone
from typing import Optional, Union, Dict, Any, List, Tuple
from uuid import UUID
from sqlalchemy import or_, asc, desc
from sqlalchemy import or_, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_password_hash
from app.core.auth import get_password_hash_async
from app.crud.base import CRUDBase
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
@@ -15,15 +18,28 @@ logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
return db.query(User).filter(User.email == email).first()
"""Async CRUD operations for User model."""
def create(self, db: Session, *, obj_in: UserCreate) -> User:
"""Create a new user with password hashing and error handling."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
"""Get user by email address."""
try:
result = await db.execute(
select(User).where(User.email == email)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by email {email}: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with async password hashing and error handling."""
try:
# Hash password asynchronously to avoid blocking event loop
password_hash = await get_password_hash_async(obj_in.password)
db_obj = User(
email=obj_in.email,
password_hash=get_password_hash(obj_in.password),
password_hash=password_hash,
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
@@ -31,11 +47,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
preferences={}
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
@@ -43,32 +59,34 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
logger.error(f"Integrity error creating user: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
raise
def update(
self,
db: Session,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
async def update(
self,
db: AsyncSession,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data:
update_data["password_hash"] = get_password_hash(update_data["password"])
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
del update_data["password"]
return super().update(db, db_obj=db_obj, obj_in=update_data)
return await super().update(db, db_obj=db_obj, obj_in=update_data)
def get_multi_with_total(
async def get_multi_with_total(
self,
db: Session,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
@@ -102,16 +120,16 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
try:
# Build base query
query = db.query(User)
query = select(User)
# Exclude soft-deleted users
query = query.filter(User.deleted_at.is_(None))
query = query.where(User.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(User, field) and value is not None:
query = query.filter(getattr(User, field) == value)
query = query.where(getattr(User, field) == value)
# Apply search
if search:
@@ -120,21 +138,26 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%")
)
query = query.filter(search_filter)
query = query.where(search_filter)
# Get total count
total = query.count()
from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(User, sort_by):
sort_column = getattr(User, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(desc(sort_column))
query = query.order_by(sort_column.desc())
else:
query = query.order_by(asc(sort_column))
query = query.order_by(sort_column.asc())
# Apply pagination
users = query.offset(skip).limit(limit).all()
query = query.offset(skip).limit(limit)
result = await db.execute(query)
users = list(result.scalars().all())
return users, total
@@ -142,12 +165,108 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
logger.error(f"Error retrieving paginated users: {str(e)}")
raise
async def bulk_update_status(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
Args:
db: Database session
user_ids: List of user IDs to update
is_active: New active status
Returns:
Number of users updated
"""
try:
if not user_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
)
result = await db.execute(stmt)
await db.commit()
updated_count = result.rowcount
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
return updated_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
exclude_user_id: Optional[UUID] = None
) -> int:
"""
Bulk soft delete multiple users.
Args:
db: Database session
user_ids: List of user IDs to delete
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
Returns:
Number of users deleted
"""
try:
if not user_ids:
return 0
# Remove excluded user from list
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
.values(
deleted_at=datetime.now(timezone.utc),
is_active=False,
updated_at=datetime.now(timezone.utc)
)
)
result = await db.execute(stmt)
await db.commit()
deleted_count = result.rowcount
logger.info(f"Bulk soft deleted {deleted_count} users")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
raise
def is_active(self, user: User) -> bool:
"""Check if user is active."""
return user.is_active
def is_superuser(self, user: User) -> bool:
"""Check if user is a superuser."""
return user.is_superuser
# Create a singleton instance for use across the application
user = CRUDUser(User)
user = CRUDUser(User)

View File

@@ -1,272 +0,0 @@
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import datetime, timezone
from typing import Optional, Union, Dict, Any, List, Tuple
from uuid import UUID
from sqlalchemy import or_, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_password_hash_async
from app.crud.base_async import CRUDBaseAsync
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
logger = logging.getLogger(__name__)
class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
"""Get user by email address."""
try:
result = await db.execute(
select(User).where(User.email == email)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by email {email}: {str(e)}")
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with async password hashing and error handling."""
try:
# Hash password asynchronously to avoid blocking event loop
password_hash = await get_password_hash_async(obj_in.password)
db_obj = User(
email=obj_in.email,
password_hash=password_hash,
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise ValueError(f"User with email {obj_in.email} already exists")
logger.error(f"Integrity error creating user: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User:
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data)
async def get_multi_with_total(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None,
search: Optional[str] = None
) -> Tuple[List[User], int]:
"""
Get multiple users with total count, filtering, sorting, and search.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
search: Search term to match against email, first_name, last_name
Returns:
Tuple of (users list, total count)
"""
# Validate pagination
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
# Build base query
query = select(User)
# Exclude soft-deleted users
query = query.where(User.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(User, field) and value is not None:
query = query.where(getattr(User, field) == value)
# Apply search
if search:
search_filter = or_(
User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%")
)
query = query.where(search_filter)
# Get total count
from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(User, sort_by):
sort_column = getattr(User, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
users = list(result.scalars().all())
return users, total
except Exception as e:
logger.error(f"Error retrieving paginated users: {str(e)}")
raise
async def bulk_update_status(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
Args:
db: Database session
user_ids: List of user IDs to update
is_active: New active status
Returns:
Number of users updated
"""
try:
if not user_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
)
result = await db.execute(stmt)
await db.commit()
updated_count = result.rowcount
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
return updated_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
exclude_user_id: Optional[UUID] = None
) -> int:
"""
Bulk soft delete multiple users.
Args:
db: Database session
user_ids: List of user IDs to delete
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
Returns:
Number of users deleted
"""
try:
if not user_ids:
return 0
# Remove excluded user from list
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
.values(
deleted_at=datetime.now(timezone.utc),
is_active=False,
updated_at=datetime.now(timezone.utc)
)
)
result = await db.execute(stmt)
await db.commit()
deleted_count = result.rowcount
logger.info(f"Bulk soft deleted {deleted_count} users")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
raise
def is_active(self, user: User) -> bool:
"""Check if user is active."""
return user.is_active
def is_superuser(self, user: User) -> bool:
"""Check if user is a superuser."""
return user.is_superuser
# Create a singleton instance for use across the application
user_async = CRUDUserAsync(User)

View File

@@ -1,78 +0,0 @@
# app/init_db.py
import logging
from typing import Optional
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.database import engine
from app.crud.user import user as user_crud
from app.schemas.users import UserCreate
logger = logging.getLogger(__name__)
def init_db(db: Session) -> Optional[UserCreate]:
"""
Initialize database with first superuser if settings are configured and user doesn't exist.
Returns:
The created or existing superuser, or None if creation fails
"""
# Use default values if not set in environment variables
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "Admin123!Change"
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
logger.warning(
"First superuser credentials not configured in settings. "
f"Using defaults: {superuser_email}"
)
try:
# Check if superuser already exists
existing_user = user_crud.get_by_email(db, email=superuser_email)
if existing_user:
logger.info(f"Superuser already exists: {existing_user.email}")
return existing_user
# Create superuser if doesn't exist
user_in = UserCreate(
email=superuser_email,
password=superuser_password,
first_name="Admin",
last_name="User",
is_superuser=True
)
user = user_crud.create(db, obj_in=user_in)
logger.info(f"Created first superuser: {user.email}")
return user
except Exception as e:
logger.error(f"Error initializing database: {e}")
raise
if __name__ == "__main__":
# Configure logging to show info logs
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
with Session(engine) as session:
try:
user = init_db(session)
if user:
print(f"✓ Database initialized successfully")
print(f"✓ Superuser: {user.email}")
else:
print("✗ Failed to initialize database")
except Exception as e:
print(f"✗ Error initializing database: {e}")
raise
finally:
session.close()

View File

@@ -13,7 +13,7 @@ from slowapi.util import get_remote_address
from app.api.main import api_router
from app.core.config import settings
from app.core.database_async import check_database_health
from app.core.database import check_database_health
from app.core.exceptions import (
APIException,
api_exception_handler,

View File

@@ -6,8 +6,8 @@ This service runs periodically to remove old session records from the database.
import logging
from datetime import datetime, timezone
from app.core.database_async import AsyncSessionLocal
from app.crud.session_async import session_async as session_crud
from app.core.database import SessionLocal
from app.crud.session import session as session_crud
logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
"""
logger.info("Starting session cleanup job...")
async with AsyncSessionLocal() as db:
async with SessionLocal() as db:
try:
# Use CRUD method to cleanup
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
@@ -50,7 +50,7 @@ async def get_session_statistics() -> dict:
Returns:
Dictionary with session stats
"""
async with AsyncSessionLocal() as db:
async with SessionLocal() as db:
try:
from app.models.user_session import UserSession
from sqlalchemy import select, func