forked from cardosofelipe/fast-next-template
Remove unused async database and CRUD modules
- Deleted `database_async.py`, `base_async.py`, and `organization_async.py` modules due to deprecation and unused references across the project. - Improved overall codebase clarity and minimized redundant functionality by removing unused async database logic, CRUD utilities, and organization-related operations.
This commit is contained in:
@@ -7,7 +7,7 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||||
from app.core.database_async import get_async_db
|
from app.core.database import get_db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
# OAuth2 configuration
|
# OAuth2 configuration
|
||||||
@@ -15,7 +15,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
|||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
db: AsyncSession = Depends(get_async_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
token: str = Depends(oauth2_scheme)
|
token: str = Depends(oauth2_scheme)
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
@@ -139,7 +139,7 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
|
|||||||
|
|
||||||
|
|
||||||
async def get_optional_current_user(
|
async def get_optional_current_user(
|
||||||
db: AsyncSession = Depends(get_async_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
token: Optional[str] = Depends(get_optional_token)
|
token: Optional[str] = Depends(get_optional_token)
|
||||||
) -> Optional[User]:
|
) -> Optional[User]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ from fastapi import Depends, HTTPException, status
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.core.database_async import get_async_db
|
from app.core.database import get_db
|
||||||
from app.crud.organization_async import organization_async as organization_crud
|
from app.crud.organization import organization as organization_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole
|
from app.models.user_organization import OrganizationRole
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ class OrganizationPermission:
|
|||||||
self,
|
self,
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Check if user has required role in the organization.
|
Check if user has required role in the organization.
|
||||||
@@ -133,7 +133,7 @@ require_org_member = OrganizationPermission([
|
|||||||
async def get_current_org_role(
|
async def get_current_org_role(
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Optional[OrganizationRole]:
|
) -> Optional[OrganizationRole]:
|
||||||
"""
|
"""
|
||||||
Get the current user's role in an organization.
|
Get the current user's role in an organization.
|
||||||
@@ -164,7 +164,7 @@ async def get_current_org_role(
|
|||||||
async def require_org_membership(
|
async def require_org_membership(
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Ensure user is a member of the organization (any role).
|
Ensure user is a member of the organization (any role).
|
||||||
|
|||||||
@@ -15,10 +15,10 @@ from pydantic import BaseModel, Field
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.dependencies.permissions import require_superuser
|
from app.api.dependencies.permissions import require_superuser
|
||||||
from app.core.database_async import get_async_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode
|
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode
|
||||||
from app.crud.organization_async import organization_async as organization_crud
|
from app.crud.organization import organization as organization_crud
|
||||||
from app.crud.user_async import user_async as user_crud
|
from app.crud.user import user as user_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole
|
from app.models.user_organization import OrganizationRole
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
@@ -80,7 +80,7 @@ async def admin_list_users(
|
|||||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||||
search: Optional[str] = Query(None, description="Search by email, name"),
|
search: Optional[str] = Query(None, description="Search by email, name"),
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
List all users with comprehensive filtering and search.
|
List all users with comprehensive filtering and search.
|
||||||
@@ -131,7 +131,7 @@ async def admin_list_users(
|
|||||||
async def admin_create_user(
|
async def admin_create_user(
|
||||||
user_in: UserCreate,
|
user_in: UserCreate,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Create a new user with admin privileges.
|
Create a new user with admin privileges.
|
||||||
@@ -163,7 +163,7 @@ async def admin_create_user(
|
|||||||
async def admin_get_user(
|
async def admin_get_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Get detailed information about a specific user."""
|
"""Get detailed information about a specific user."""
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_crud.get(db, id=user_id)
|
||||||
@@ -186,7 +186,7 @@ async def admin_update_user(
|
|||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
user_in: UserUpdate,
|
user_in: UserUpdate,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Update user information with admin privileges."""
|
"""Update user information with admin privileges."""
|
||||||
try:
|
try:
|
||||||
@@ -218,7 +218,7 @@ async def admin_update_user(
|
|||||||
async def admin_delete_user(
|
async def admin_delete_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Soft delete a user (sets deleted_at timestamp)."""
|
"""Soft delete a user (sets deleted_at timestamp)."""
|
||||||
try:
|
try:
|
||||||
@@ -262,7 +262,7 @@ async def admin_delete_user(
|
|||||||
async def admin_activate_user(
|
async def admin_activate_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Activate a user account."""
|
"""Activate a user account."""
|
||||||
try:
|
try:
|
||||||
@@ -298,7 +298,7 @@ async def admin_activate_user(
|
|||||||
async def admin_deactivate_user(
|
async def admin_deactivate_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Deactivate a user account."""
|
"""Deactivate a user account."""
|
||||||
try:
|
try:
|
||||||
@@ -342,7 +342,7 @@ async def admin_deactivate_user(
|
|||||||
async def admin_bulk_user_action(
|
async def admin_bulk_user_action(
|
||||||
bulk_action: BulkUserAction,
|
bulk_action: BulkUserAction,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Perform bulk actions on multiple users using optimized bulk operations.
|
Perform bulk actions on multiple users using optimized bulk operations.
|
||||||
@@ -410,7 +410,7 @@ async def admin_list_organizations(
|
|||||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||||
search: Optional[str] = Query(None, description="Search by name, slug, description"),
|
search: Optional[str] = Query(None, description="Search by name, slug, description"),
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""List all organizations with filtering and search."""
|
"""List all organizations with filtering and search."""
|
||||||
try:
|
try:
|
||||||
@@ -467,7 +467,7 @@ async def admin_list_organizations(
|
|||||||
async def admin_create_organization(
|
async def admin_create_organization(
|
||||||
org_in: OrganizationCreate,
|
org_in: OrganizationCreate,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Create a new organization."""
|
"""Create a new organization."""
|
||||||
try:
|
try:
|
||||||
@@ -509,7 +509,7 @@ async def admin_create_organization(
|
|||||||
async def admin_get_organization(
|
async def admin_get_organization(
|
||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Get detailed information about a specific organization."""
|
"""Get detailed information about a specific organization."""
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_crud.get(db, id=org_id)
|
||||||
@@ -544,7 +544,7 @@ async def admin_update_organization(
|
|||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
org_in: OrganizationUpdate,
|
org_in: OrganizationUpdate,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Update organization information."""
|
"""Update organization information."""
|
||||||
try:
|
try:
|
||||||
@@ -588,7 +588,7 @@ async def admin_update_organization(
|
|||||||
async def admin_delete_organization(
|
async def admin_delete_organization(
|
||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Delete an organization and all its relationships."""
|
"""Delete an organization and all its relationships."""
|
||||||
try:
|
try:
|
||||||
@@ -626,7 +626,7 @@ async def admin_list_organization_members(
|
|||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
is_active: Optional[bool] = Query(True, description="Filter by active status"),
|
is_active: Optional[bool] = Query(True, description="Filter by active status"),
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""List all members of an organization."""
|
"""List all members of an organization."""
|
||||||
try:
|
try:
|
||||||
@@ -681,7 +681,7 @@ async def admin_add_organization_member(
|
|||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
request: AddMemberRequest,
|
request: AddMemberRequest,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Add a user to an organization."""
|
"""Add a user to an organization."""
|
||||||
try:
|
try:
|
||||||
@@ -742,7 +742,7 @@ async def admin_remove_organization_member(
|
|||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Remove a user from an organization."""
|
"""Remove a user from an organization."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -13,14 +13,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||||
from app.core.auth import get_password_hash
|
from app.core.auth import get_password_hash
|
||||||
from app.core.database_async import get_async_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
AuthenticationError as AuthError,
|
AuthenticationError as AuthError,
|
||||||
DatabaseError,
|
DatabaseError,
|
||||||
ErrorCode
|
ErrorCode
|
||||||
)
|
)
|
||||||
from app.crud.session_async import session_async as session_crud
|
from app.crud.session import session as session_crud
|
||||||
from app.crud.user_async import user_async as user_crud
|
from app.crud.user import user as user_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.schemas.sessions import SessionCreate, LogoutRequest
|
from app.schemas.sessions import SessionCreate, LogoutRequest
|
||||||
@@ -54,7 +54,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1
|
|||||||
async def register_user(
|
async def register_user(
|
||||||
request: Request,
|
request: Request,
|
||||||
user_data: UserCreate,
|
user_data: UserCreate,
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Register a new user.
|
Register a new user.
|
||||||
@@ -85,7 +85,7 @@ async def register_user(
|
|||||||
async def login(
|
async def login(
|
||||||
request: Request,
|
request: Request,
|
||||||
login_data: LoginRequest,
|
login_data: LoginRequest,
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Login with username and password.
|
Login with username and password.
|
||||||
@@ -167,7 +167,7 @@ async def login(
|
|||||||
async def login_oauth(
|
async def login_oauth(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||||
@@ -244,7 +244,7 @@ async def login_oauth(
|
|||||||
async def refresh_token(
|
async def refresh_token(
|
||||||
request: Request,
|
request: Request,
|
||||||
refresh_data: RefreshTokenRequest,
|
refresh_data: RefreshTokenRequest,
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Refresh access token using a refresh token.
|
Refresh access token using a refresh token.
|
||||||
@@ -333,7 +333,7 @@ async def refresh_token(
|
|||||||
async def request_password_reset(
|
async def request_password_reset(
|
||||||
request: Request,
|
request: Request,
|
||||||
reset_request: PasswordResetRequest,
|
reset_request: PasswordResetRequest,
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Request a password reset.
|
Request a password reset.
|
||||||
@@ -391,7 +391,7 @@ async def request_password_reset(
|
|||||||
async def confirm_password_reset(
|
async def confirm_password_reset(
|
||||||
request: Request,
|
request: Request,
|
||||||
reset_confirm: PasswordResetConfirm,
|
reset_confirm: PasswordResetConfirm,
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Confirm password reset with token.
|
Confirm password reset with token.
|
||||||
@@ -430,7 +430,7 @@ async def confirm_password_reset(
|
|||||||
|
|
||||||
# SECURITY: Invalidate all existing sessions after password reset
|
# SECURITY: Invalidate all existing sessions after password reset
|
||||||
# This prevents stolen sessions from being used after password change
|
# This prevents stolen sessions from being used after password change
|
||||||
from app.crud.session_async import session_async as session_crud
|
from app.crud.session import session as session_crud
|
||||||
try:
|
try:
|
||||||
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
||||||
db,
|
db,
|
||||||
@@ -478,7 +478,7 @@ async def logout(
|
|||||||
request: Request,
|
request: Request,
|
||||||
logout_request: LogoutRequest,
|
logout_request: LogoutRequest,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Logout from current device by deactivating the session.
|
Logout from current device by deactivating the session.
|
||||||
@@ -566,7 +566,7 @@ async def logout(
|
|||||||
async def logout_all(
|
async def logout_all(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Logout from all devices by deactivating all user sessions.
|
Logout from all devices by deactivating all user sessions.
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
||||||
from app.core.database_async import get_async_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import NotFoundError, ErrorCode
|
from app.core.exceptions import NotFoundError, ErrorCode
|
||||||
from app.crud.organization_async import organization_async as organization_crud
|
from app.crud.organization import organization as organization_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
PaginationParams,
|
PaginationParams,
|
||||||
@@ -43,7 +43,7 @@ router = APIRouter()
|
|||||||
async def get_my_organizations(
|
async def get_my_organizations(
|
||||||
is_active: bool = Query(True, description="Filter by active membership"),
|
is_active: bool = Query(True, description="Filter by active membership"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get all organizations the current user belongs to.
|
Get all organizations the current user belongs to.
|
||||||
@@ -93,7 +93,7 @@ async def get_my_organizations(
|
|||||||
async def get_organization(
|
async def get_organization(
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
current_user: User = Depends(require_org_membership),
|
current_user: User = Depends(require_org_membership),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get details of a specific organization.
|
Get details of a specific organization.
|
||||||
@@ -140,7 +140,7 @@ async def get_organization_members(
|
|||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
is_active: bool = Query(True, description="Filter by active status"),
|
is_active: bool = Query(True, description="Filter by active status"),
|
||||||
current_user: User = Depends(require_org_membership),
|
current_user: User = Depends(require_org_membership),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get all members of an organization.
|
Get all members of an organization.
|
||||||
@@ -183,7 +183,7 @@ async def update_organization(
|
|||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
org_in: OrganizationUpdate,
|
org_in: OrganizationUpdate,
|
||||||
current_user: User = Depends(require_org_admin),
|
current_user: User = Depends(require_org_admin),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update organization details.
|
Update organization details.
|
||||||
|
|||||||
@@ -14,9 +14,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.core.auth import decode_token
|
from app.core.auth import decode_token
|
||||||
from app.core.database_async import get_async_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||||
from app.crud.session_async import session_async as session_crud
|
from app.crud.session import session as session_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||||
@@ -45,7 +45,7 @@ limiter = Limiter(key_func=get_remote_address)
|
|||||||
async def list_my_sessions(
|
async def list_my_sessions(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
List all active sessions for the current user.
|
List all active sessions for the current user.
|
||||||
@@ -129,7 +129,7 @@ async def revoke_session(
|
|||||||
request: Request,
|
request: Request,
|
||||||
session_id: UUID,
|
session_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Revoke a specific session by ID.
|
Revoke a specific session by ID.
|
||||||
@@ -204,7 +204,7 @@ async def revoke_session(
|
|||||||
async def cleanup_expired_sessions(
|
async def cleanup_expired_sessions(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Cleanup expired sessions for the current user.
|
Cleanup expired sessions for the current user.
|
||||||
|
|||||||
@@ -11,13 +11,13 @@ from slowapi.util import get_remote_address
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||||
from app.core.database_async import get_async_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
AuthorizationError,
|
AuthorizationError,
|
||||||
ErrorCode
|
ErrorCode
|
||||||
)
|
)
|
||||||
from app.crud.user_async import user_async as user_crud
|
from app.crud.user import user as user_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
PaginationParams,
|
PaginationParams,
|
||||||
@@ -58,7 +58,7 @@ async def list_users(
|
|||||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
List all users with pagination, filtering, and sorting.
|
List all users with pagination, filtering, and sorting.
|
||||||
@@ -138,7 +138,7 @@ def get_current_user_profile(
|
|||||||
async def update_current_user(
|
async def update_current_user(
|
||||||
user_update: UserUpdate,
|
user_update: UserUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update current user's profile.
|
Update current user's profile.
|
||||||
@@ -188,7 +188,7 @@ async def update_current_user(
|
|||||||
async def get_user_by_id(
|
async def get_user_by_id(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get user by ID.
|
Get user by ID.
|
||||||
@@ -236,7 +236,7 @@ async def update_user(
|
|||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
user_update: UserUpdate,
|
user_update: UserUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update user by ID.
|
Update user by ID.
|
||||||
@@ -304,7 +304,7 @@ async def change_current_user_password(
|
|||||||
request: Request,
|
request: Request,
|
||||||
password_change: PasswordChange,
|
password_change: PasswordChange,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Change current user's password.
|
Change current user's password.
|
||||||
@@ -356,7 +356,7 @@ async def change_current_user_password(
|
|||||||
async def delete_user(
|
async def delete_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Delete user by ID (superuser only).
|
Delete user by ID (superuser only).
|
||||||
|
|||||||
207
backend/app/core/database.py
Normal file → Executable file
207
backend/app/core/database.py
Normal file → Executable file
@@ -1,113 +1,186 @@
|
|||||||
# app/core/database.py
|
# app/core/database.py
|
||||||
import logging
|
"""
|
||||||
from contextlib import contextmanager
|
Database configuration using SQLAlchemy 2.0 and asyncpg.
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
from sqlalchemy import create_engine, text
|
This module provides async database connectivity with proper connection pooling
|
||||||
|
and session management for FastAPI endpoints.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncSession,
|
||||||
|
AsyncEngine,
|
||||||
|
create_async_engine,
|
||||||
|
async_sessionmaker,
|
||||||
|
)
|
||||||
from sqlalchemy.ext.compiler import compiles
|
from sqlalchemy.ext.compiler import compiles
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from sqlalchemy.orm import sessionmaker, Session
|
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# SQLite compatibility for testing
|
# SQLite compatibility for testing
|
||||||
@compiles(JSONB, 'sqlite')
|
@compiles(JSONB, 'sqlite')
|
||||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||||
return "TEXT"
|
return "TEXT"
|
||||||
|
|
||||||
|
|
||||||
@compiles(UUID, 'sqlite')
|
@compiles(UUID, 'sqlite')
|
||||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||||
return "TEXT"
|
return "TEXT"
|
||||||
|
|
||||||
# Declarative base for models
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
# Create engine with optimized settings for PostgreSQL
|
# Declarative base for models (SQLAlchemy 2.0 style)
|
||||||
def create_production_engine():
|
class Base(DeclarativeBase):
|
||||||
return create_engine(
|
"""Base class for all database models."""
|
||||||
settings.database_url,
|
pass
|
||||||
# Connection pool settings
|
|
||||||
pool_size=settings.db_pool_size,
|
|
||||||
max_overflow=settings.db_max_overflow,
|
|
||||||
pool_timeout=settings.db_pool_timeout,
|
|
||||||
pool_recycle=settings.db_pool_recycle,
|
|
||||||
pool_pre_ping=True,
|
|
||||||
# Query execution settings
|
|
||||||
connect_args={
|
|
||||||
"application_name": "eventspace",
|
|
||||||
"keepalives": 1,
|
|
||||||
"keepalives_idle": 60,
|
|
||||||
"keepalives_interval": 10,
|
|
||||||
"keepalives_count": 5,
|
|
||||||
"options": "-c timezone=UTC",
|
|
||||||
},
|
|
||||||
isolation_level="READ COMMITTED",
|
|
||||||
echo=settings.sql_echo,
|
|
||||||
echo_pool=settings.sql_echo_pool,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Default production engine and session factory
|
|
||||||
engine = create_production_engine()
|
def get_async_database_url(url: str) -> str:
|
||||||
SessionLocal = sessionmaker(
|
"""
|
||||||
|
Convert sync database URL to async URL.
|
||||||
|
|
||||||
|
postgresql:// -> postgresql+asyncpg://
|
||||||
|
sqlite:// -> sqlite+aiosqlite://
|
||||||
|
"""
|
||||||
|
if url.startswith("postgresql://"):
|
||||||
|
return url.replace("postgresql://", "postgresql+asyncpg://")
|
||||||
|
elif url.startswith("sqlite://"):
|
||||||
|
return url.replace("sqlite://", "sqlite+aiosqlite://")
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
# Create async engine with optimized settings
|
||||||
|
def create_async_production_engine() -> AsyncEngine:
|
||||||
|
"""Create an async database engine with production settings."""
|
||||||
|
async_url = get_async_database_url(settings.database_url)
|
||||||
|
|
||||||
|
# Base engine config
|
||||||
|
engine_config = {
|
||||||
|
"pool_size": settings.db_pool_size,
|
||||||
|
"max_overflow": settings.db_max_overflow,
|
||||||
|
"pool_timeout": settings.db_pool_timeout,
|
||||||
|
"pool_recycle": settings.db_pool_recycle,
|
||||||
|
"pool_pre_ping": True,
|
||||||
|
"echo": settings.sql_echo,
|
||||||
|
"echo_pool": settings.sql_echo_pool,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add PostgreSQL-specific connect_args
|
||||||
|
if "postgresql" in async_url:
|
||||||
|
engine_config["connect_args"] = {
|
||||||
|
"server_settings": {
|
||||||
|
"application_name": "eventspace",
|
||||||
|
"timezone": "UTC",
|
||||||
|
},
|
||||||
|
# asyncpg-specific settings
|
||||||
|
"command_timeout": 60,
|
||||||
|
"timeout": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
return create_async_engine(async_url, **engine_config)
|
||||||
|
|
||||||
|
|
||||||
|
# Create async engine and session factory
|
||||||
|
engine = create_async_production_engine()
|
||||||
|
SessionLocal = async_sessionmaker(
|
||||||
|
engine,
|
||||||
|
class_=AsyncSession,
|
||||||
autocommit=False,
|
autocommit=False,
|
||||||
autoflush=False,
|
autoflush=False,
|
||||||
bind=engine,
|
expire_on_commit=False, # Prevent unnecessary queries after commit
|
||||||
expire_on_commit=False # Prevent unnecessary queries after commit
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# FastAPI dependency
|
|
||||||
def get_db() -> Generator[Session, None, None]:
|
# FastAPI dependency for async database sessions
|
||||||
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""
|
"""
|
||||||
FastAPI dependency that provides a database session.
|
FastAPI dependency that provides an async database session.
|
||||||
Automatically closes the session after the request completes.
|
Automatically closes the session after the request completes.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@router.get("/users")
|
||||||
|
async def get_users(db: AsyncSession = Depends(get_db)):
|
||||||
|
result = await db.execute(select(User))
|
||||||
|
return result.scalars().all()
|
||||||
"""
|
"""
|
||||||
db = SessionLocal()
|
async with SessionLocal() as session:
|
||||||
try:
|
try:
|
||||||
yield db
|
yield session
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@asynccontextmanager
|
||||||
def transaction_scope() -> Generator[Session, None, None]:
|
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""
|
"""
|
||||||
Provide a transactional scope for database operations.
|
Provide an async transactional scope for database operations.
|
||||||
|
|
||||||
Automatically commits on success or rolls back on exception.
|
Automatically commits on success or rolls back on exception.
|
||||||
Useful for grouping multiple operations in a single transaction.
|
Useful for grouping multiple operations in a single transaction.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
with transaction_scope() as db:
|
async with async_transaction_scope() as db:
|
||||||
user = user_crud.create(db, obj_in=user_create)
|
user = await user_crud.create(db, obj_in=user_create)
|
||||||
profile = profile_crud.create(db, obj_in=profile_create)
|
profile = await profile_crud.create(db, obj_in=profile_create)
|
||||||
# Both operations committed together
|
# Both operations committed together
|
||||||
"""
|
"""
|
||||||
db = SessionLocal()
|
async with SessionLocal() as session:
|
||||||
try:
|
try:
|
||||||
yield db
|
yield session
|
||||||
db.commit()
|
await session.commit()
|
||||||
logger.debug("Transaction committed successfully")
|
logger.debug("Async transaction committed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await session.rollback()
|
||||||
logger.error(f"Transaction failed, rolling back: {str(e)}")
|
logger.error(f"Async transaction failed, rolling back: {str(e)}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
def check_database_health() -> bool:
|
async def check_async_database_health() -> bool:
|
||||||
"""
|
"""
|
||||||
Check if database connection is healthy.
|
Check if async database connection is healthy.
|
||||||
Returns True if connection is successful, False otherwise.
|
Returns True if connection is successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with transaction_scope() as db:
|
async with async_transaction_scope() as db:
|
||||||
db.execute(text("SELECT 1"))
|
await db.execute(text("SELECT 1"))
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Database health check failed: {str(e)}")
|
logger.error(f"Async database health check failed: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for consistency with main.py
|
||||||
|
check_database_health = check_async_database_health
|
||||||
|
|
||||||
|
|
||||||
|
async def init_async_db() -> None:
|
||||||
|
"""
|
||||||
|
Initialize async database tables.
|
||||||
|
|
||||||
|
This creates all tables defined in the models.
|
||||||
|
Should only be used in development or testing.
|
||||||
|
In production, use Alembic migrations.
|
||||||
|
"""
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
logger.info("Async database tables created")
|
||||||
|
|
||||||
|
|
||||||
|
async def close_async_db() -> None:
|
||||||
|
"""
|
||||||
|
Close all async database connections.
|
||||||
|
|
||||||
|
Should be called during application shutdown.
|
||||||
|
"""
|
||||||
|
await engine.dispose()
|
||||||
|
logger.info("Async database connections closed")
|
||||||
|
|||||||
@@ -1,186 +0,0 @@
|
|||||||
# app/core/database_async.py
|
|
||||||
"""
|
|
||||||
Async database configuration using SQLAlchemy 2.0 and asyncpg.
|
|
||||||
|
|
||||||
This module provides async database connectivity with proper connection pooling
|
|
||||||
and session management for FastAPI endpoints.
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
|
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
|
||||||
from sqlalchemy.ext.asyncio import (
|
|
||||||
AsyncSession,
|
|
||||||
AsyncEngine,
|
|
||||||
create_async_engine,
|
|
||||||
async_sessionmaker,
|
|
||||||
)
|
|
||||||
from sqlalchemy.ext.compiler import compiles
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# SQLite compatibility for testing
|
|
||||||
@compiles(JSONB, 'sqlite')
|
|
||||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
|
||||||
return "TEXT"
|
|
||||||
|
|
||||||
|
|
||||||
@compiles(UUID, 'sqlite')
|
|
||||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
|
||||||
return "TEXT"
|
|
||||||
|
|
||||||
|
|
||||||
# Declarative base for models (SQLAlchemy 2.0 style)
|
|
||||||
class Base(DeclarativeBase):
|
|
||||||
"""Base class for all database models."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def get_async_database_url(url: str) -> str:
|
|
||||||
"""
|
|
||||||
Convert sync database URL to async URL.
|
|
||||||
|
|
||||||
postgresql:// -> postgresql+asyncpg://
|
|
||||||
sqlite:// -> sqlite+aiosqlite://
|
|
||||||
"""
|
|
||||||
if url.startswith("postgresql://"):
|
|
||||||
return url.replace("postgresql://", "postgresql+asyncpg://")
|
|
||||||
elif url.startswith("sqlite://"):
|
|
||||||
return url.replace("sqlite://", "sqlite+aiosqlite://")
|
|
||||||
return url
|
|
||||||
|
|
||||||
|
|
||||||
# Create async engine with optimized settings
|
|
||||||
def create_async_production_engine() -> AsyncEngine:
|
|
||||||
"""Create an async database engine with production settings."""
|
|
||||||
async_url = get_async_database_url(settings.database_url)
|
|
||||||
|
|
||||||
# Base engine config
|
|
||||||
engine_config = {
|
|
||||||
"pool_size": settings.db_pool_size,
|
|
||||||
"max_overflow": settings.db_max_overflow,
|
|
||||||
"pool_timeout": settings.db_pool_timeout,
|
|
||||||
"pool_recycle": settings.db_pool_recycle,
|
|
||||||
"pool_pre_ping": True,
|
|
||||||
"echo": settings.sql_echo,
|
|
||||||
"echo_pool": settings.sql_echo_pool,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add PostgreSQL-specific connect_args
|
|
||||||
if "postgresql" in async_url:
|
|
||||||
engine_config["connect_args"] = {
|
|
||||||
"server_settings": {
|
|
||||||
"application_name": "eventspace",
|
|
||||||
"timezone": "UTC",
|
|
||||||
},
|
|
||||||
# asyncpg-specific settings
|
|
||||||
"command_timeout": 60,
|
|
||||||
"timeout": 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
return create_async_engine(async_url, **engine_config)
|
|
||||||
|
|
||||||
|
|
||||||
# Create async engine and session factory
|
|
||||||
async_engine = create_async_production_engine()
|
|
||||||
AsyncSessionLocal = async_sessionmaker(
|
|
||||||
async_engine,
|
|
||||||
class_=AsyncSession,
|
|
||||||
autocommit=False,
|
|
||||||
autoflush=False,
|
|
||||||
expire_on_commit=False, # Prevent unnecessary queries after commit
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# FastAPI dependency for async database sessions
|
|
||||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
|
||||||
"""
|
|
||||||
FastAPI dependency that provides an async database session.
|
|
||||||
Automatically closes the session after the request completes.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
@router.get("/users")
|
|
||||||
async def get_users(db: AsyncSession = Depends(get_async_db)):
|
|
||||||
result = await db.execute(select(User))
|
|
||||||
return result.scalars().all()
|
|
||||||
"""
|
|
||||||
async with AsyncSessionLocal() as session:
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
|
||||||
"""
|
|
||||||
Provide an async transactional scope for database operations.
|
|
||||||
|
|
||||||
Automatically commits on success or rolls back on exception.
|
|
||||||
Useful for grouping multiple operations in a single transaction.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
async with async_transaction_scope() as db:
|
|
||||||
user = await user_crud.create(db, obj_in=user_create)
|
|
||||||
profile = await profile_crud.create(db, obj_in=profile_create)
|
|
||||||
# Both operations committed together
|
|
||||||
"""
|
|
||||||
async with AsyncSessionLocal() as session:
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
await session.commit()
|
|
||||||
logger.debug("Async transaction committed successfully")
|
|
||||||
except Exception as e:
|
|
||||||
await session.rollback()
|
|
||||||
logger.error(f"Async transaction failed, rolling back: {str(e)}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def check_async_database_health() -> bool:
|
|
||||||
"""
|
|
||||||
Check if async database connection is healthy.
|
|
||||||
Returns True if connection is successful, False otherwise.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
async with async_transaction_scope() as db:
|
|
||||||
await db.execute(text("SELECT 1"))
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Async database health check failed: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# Alias for consistency with main.py
|
|
||||||
check_database_health = check_async_database_health
|
|
||||||
|
|
||||||
|
|
||||||
async def init_async_db() -> None:
|
|
||||||
"""
|
|
||||||
Initialize async database tables.
|
|
||||||
|
|
||||||
This creates all tables defined in the models.
|
|
||||||
Should only be used in development or testing.
|
|
||||||
In production, use Alembic migrations.
|
|
||||||
"""
|
|
||||||
async with async_engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
logger.info("Async database tables created")
|
|
||||||
|
|
||||||
|
|
||||||
async def close_async_db() -> None:
|
|
||||||
"""
|
|
||||||
Close all async database connections.
|
|
||||||
|
|
||||||
Should be called during application shutdown.
|
|
||||||
"""
|
|
||||||
await async_engine.dispose()
|
|
||||||
logger.info("Async database connections closed")
|
|
||||||
207
backend/app/crud/base.py
Normal file → Executable file
207
backend/app/crud/base.py
Normal file → Executable file
@@ -1,13 +1,19 @@
|
|||||||
|
# app/crud/base_async.py
|
||||||
|
"""
|
||||||
|
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
||||||
|
|
||||||
|
Provides reusable create, read, update, and delete operations for all models.
|
||||||
|
"""
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||||
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import asc, desc
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import Load
|
||||||
|
|
||||||
from app.core.database import Base
|
from app.core.database import Base
|
||||||
|
|
||||||
@@ -19,17 +25,40 @@ UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
|||||||
|
|
||||||
|
|
||||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||||
|
"""Async CRUD operations for a model."""
|
||||||
|
|
||||||
def __init__(self, model: Type[ModelType]):
|
def __init__(self, model: Type[ModelType]):
|
||||||
"""
|
"""
|
||||||
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
|
CRUD object with default async methods to Create, Read, Update, Delete.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
model: A SQLAlchemy model class
|
model: A SQLAlchemy model class
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def get(self, db: Session, id: str) -> Optional[ModelType]:
|
async def get(
|
||||||
"""Get a single record by ID with UUID validation."""
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
id: str,
|
||||||
|
options: Optional[List[Load]] = None
|
||||||
|
) -> Optional[ModelType]:
|
||||||
|
"""
|
||||||
|
Get a single record by ID with UUID validation and optional eager loading.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
id: Record UUID
|
||||||
|
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
|
||||||
|
for eager loading relationships to prevent N+1 queries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model instance or None if not found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Eager load user relationship
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
||||||
|
"""
|
||||||
# Validate UUID format and convert to UUID object if string
|
# Validate UUID format and convert to UUID object if string
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
@@ -41,15 +70,39 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return db.query(self.model).filter(self.model.id == uuid_obj).first()
|
query = select(self.model).where(self.model.id == uuid_obj)
|
||||||
|
|
||||||
|
# Apply eager loading options if provided
|
||||||
|
if options:
|
||||||
|
for option in options:
|
||||||
|
query = query.options(option)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_multi(
|
async def get_multi(
|
||||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
options: Optional[List[Load]] = None
|
||||||
) -> List[ModelType]:
|
) -> List[ModelType]:
|
||||||
"""Get multiple records with pagination validation."""
|
"""
|
||||||
|
Get multiple records with pagination validation and optional eager loading.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
skip: Number of records to skip
|
||||||
|
limit: Maximum number of records to return
|
||||||
|
options: Optional list of SQLAlchemy load options for eager loading
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model instances
|
||||||
|
"""
|
||||||
# Validate pagination parameters
|
# Validate pagination parameters
|
||||||
if skip < 0:
|
if skip < 0:
|
||||||
raise ValueError("skip must be non-negative")
|
raise ValueError("skip must be non-negative")
|
||||||
@@ -59,22 +112,30 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
raise ValueError("Maximum limit is 1000")
|
raise ValueError("Maximum limit is 1000")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return db.query(self.model).offset(skip).limit(limit).all()
|
query = select(self.model).offset(skip).limit(limit)
|
||||||
|
|
||||||
|
# Apply eager loading options if provided
|
||||||
|
if options:
|
||||||
|
for option in options:
|
||||||
|
query = query.options(option)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
|
||||||
"""Create a new record with error handling."""
|
"""Create a new record with error handling."""
|
||||||
try:
|
try:
|
||||||
obj_in_data = jsonable_encoder(obj_in)
|
obj_in_data = jsonable_encoder(obj_in)
|
||||||
db_obj = self.model(**obj_in_data)
|
db_obj = self.model(**obj_in_data)
|
||||||
db.add(db_obj)
|
db.add(db_obj)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(db_obj)
|
await db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||||
@@ -82,20 +143,20 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
except (OperationalError, DataError) as e:
|
except (OperationalError, DataError) as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||||
raise ValueError(f"Database operation failed: {str(e)}")
|
raise ValueError(f"Database operation failed: {str(e)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def update(
|
async def update(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
db_obj: ModelType,
|
db_obj: ModelType,
|
||||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||||
) -> ModelType:
|
) -> ModelType:
|
||||||
"""Update a record with error handling."""
|
"""Update a record with error handling."""
|
||||||
try:
|
try:
|
||||||
@@ -104,15 +165,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
update_data = obj_in
|
update_data = obj_in
|
||||||
else:
|
else:
|
||||||
update_data = obj_in.model_dump(exclude_unset=True)
|
update_data = obj_in.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
for field in obj_data:
|
for field in obj_data:
|
||||||
if field in update_data:
|
if field in update_data:
|
||||||
setattr(db_obj, field, update_data[field])
|
setattr(db_obj, field, update_data[field])
|
||||||
|
|
||||||
db.add(db_obj)
|
db.add(db_obj)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(db_obj)
|
await db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||||
@@ -120,15 +183,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
except (OperationalError, DataError) as e:
|
except (OperationalError, DataError) as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||||
raise ValueError(f"Database operation failed: {str(e)}")
|
raise ValueError(f"Database operation failed: {str(e)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
|
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||||
"""Delete a record with error handling and null check."""
|
"""Delete a record with error handling and null check."""
|
||||||
# Validate UUID format and convert to UUID object if string
|
# Validate UUID format and convert to UUID object if string
|
||||||
try:
|
try:
|
||||||
@@ -141,27 +204,31 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
|
result = await db.execute(
|
||||||
|
select(self.model).where(self.model.id == uuid_obj)
|
||||||
|
)
|
||||||
|
obj = result.scalar_one_or_none()
|
||||||
|
|
||||||
if obj is None:
|
if obj is None:
|
||||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
db.delete(obj)
|
await db.delete(obj)
|
||||||
db.commit()
|
await db.commit()
|
||||||
return obj
|
return obj
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_multi_with_total(
|
async def get_multi_with_total(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
@@ -193,43 +260,63 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Build base query
|
# Build base query
|
||||||
query = db.query(self.model)
|
query = select(self.model)
|
||||||
|
|
||||||
# Exclude soft-deleted records by default
|
# Exclude soft-deleted records by default
|
||||||
if hasattr(self.model, 'deleted_at'):
|
if hasattr(self.model, 'deleted_at'):
|
||||||
query = query.filter(self.model.deleted_at.is_(None))
|
query = query.where(self.model.deleted_at.is_(None))
|
||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
if filters:
|
if filters:
|
||||||
for field, value in filters.items():
|
for field, value in filters.items():
|
||||||
if hasattr(self.model, field) and value is not None:
|
if hasattr(self.model, field) and value is not None:
|
||||||
query = query.filter(getattr(self.model, field) == value)
|
query = query.where(getattr(self.model, field) == value)
|
||||||
|
|
||||||
# Get total count (before pagination)
|
# Get total count (before pagination)
|
||||||
total = query.count()
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply sorting
|
# Apply sorting
|
||||||
if sort_by and hasattr(self.model, sort_by):
|
if sort_by and hasattr(self.model, sort_by):
|
||||||
sort_column = getattr(self.model, sort_by)
|
sort_column = getattr(self.model, sort_by)
|
||||||
if sort_order.lower() == "desc":
|
if sort_order.lower() == "desc":
|
||||||
query = query.order_by(desc(sort_column))
|
query = query.order_by(sort_column.desc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(asc(sort_column))
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
items = query.offset(skip).limit(limit).all()
|
query = query.offset(skip).limit(limit)
|
||||||
|
items_result = await db.execute(query)
|
||||||
|
items = list(items_result.scalars().all())
|
||||||
|
|
||||||
return items, total
|
return items, total
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]:
|
async def count(self, db: AsyncSession) -> int:
|
||||||
|
"""Get total count of records."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(func.count(self.model.id)))
|
||||||
|
return result.scalar_one()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||||
|
"""Check if a record exists by ID."""
|
||||||
|
obj = await self.get(db, id=id)
|
||||||
|
return obj is not None
|
||||||
|
|
||||||
|
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||||
"""
|
"""
|
||||||
Soft delete a record by setting deleted_at timestamp.
|
Soft delete a record by setting deleted_at timestamp.
|
||||||
|
|
||||||
Only works if the model has a 'deleted_at' column.
|
Only works if the model has a 'deleted_at' column.
|
||||||
"""
|
"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
# Validate UUID format and convert to UUID object if string
|
# Validate UUID format and convert to UUID object if string
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
@@ -241,7 +328,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
|
result = await db.execute(
|
||||||
|
select(self.model).where(self.model.id == uuid_obj)
|
||||||
|
)
|
||||||
|
obj = result.scalar_one_or_none()
|
||||||
|
|
||||||
if obj is None:
|
if obj is None:
|
||||||
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
||||||
@@ -255,15 +345,15 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
# Set deleted_at timestamp
|
# Set deleted_at timestamp
|
||||||
obj.deleted_at = datetime.now(timezone.utc)
|
obj.deleted_at = datetime.now(timezone.utc)
|
||||||
db.add(obj)
|
db.add(obj)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(obj)
|
await db.refresh(obj)
|
||||||
return obj
|
return obj
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def restore(self, db: Session, *, id: str) -> Optional[ModelType]:
|
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||||
"""
|
"""
|
||||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||||
|
|
||||||
@@ -282,10 +372,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
try:
|
try:
|
||||||
# Find the soft-deleted record
|
# Find the soft-deleted record
|
||||||
if hasattr(self.model, 'deleted_at'):
|
if hasattr(self.model, 'deleted_at'):
|
||||||
obj = db.query(self.model).filter(
|
result = await db.execute(
|
||||||
self.model.id == uuid_obj,
|
select(self.model).where(
|
||||||
self.model.deleted_at.isnot(None)
|
self.model.id == uuid_obj,
|
||||||
).first()
|
self.model.deleted_at.isnot(None)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
obj = result.scalar_one_or_none()
|
||||||
else:
|
else:
|
||||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||||
@@ -297,10 +390,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
# Clear deleted_at timestamp
|
# Clear deleted_at timestamp
|
||||||
obj.deleted_at = None
|
obj.deleted_at = None
|
||||||
db.add(obj)
|
db.add(obj)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(obj)
|
await db.refresh(obj)
|
||||||
return obj
|
return obj
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,399 +0,0 @@
|
|||||||
# app/crud/base_async.py
|
|
||||||
"""
|
|
||||||
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
|
||||||
|
|
||||||
Provides reusable create, read, update, and delete operations for all models.
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
|
||||||
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import Load
|
|
||||||
|
|
||||||
from app.core.database_async import Base
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=Base)
|
|
||||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
|
||||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|
||||||
"""Async CRUD operations for a model."""
|
|
||||||
|
|
||||||
def __init__(self, model: Type[ModelType]):
|
|
||||||
"""
|
|
||||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
model: A SQLAlchemy model class
|
|
||||||
"""
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
async def get(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
id: str,
|
|
||||||
options: Optional[List[Load]] = None
|
|
||||||
) -> Optional[ModelType]:
|
|
||||||
"""
|
|
||||||
Get a single record by ID with UUID validation and optional eager loading.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
id: Record UUID
|
|
||||||
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
|
|
||||||
for eager loading relationships to prevent N+1 queries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model instance or None if not found
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Eager load user relationship
|
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
|
||||||
"""
|
|
||||||
# Validate UUID format and convert to UUID object if string
|
|
||||||
try:
|
|
||||||
if isinstance(id, uuid.UUID):
|
|
||||||
uuid_obj = id
|
|
||||||
else:
|
|
||||||
uuid_obj = uuid.UUID(str(id))
|
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
|
||||||
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
query = select(self.model).where(self.model.id == uuid_obj)
|
|
||||||
|
|
||||||
# Apply eager loading options if provided
|
|
||||||
if options:
|
|
||||||
for option in options:
|
|
||||||
query = query.options(option)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_multi(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
options: Optional[List[Load]] = None
|
|
||||||
) -> List[ModelType]:
|
|
||||||
"""
|
|
||||||
Get multiple records with pagination validation and optional eager loading.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
options: Optional list of SQLAlchemy load options for eager loading
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model instances
|
|
||||||
"""
|
|
||||||
# Validate pagination parameters
|
|
||||||
if skip < 0:
|
|
||||||
raise ValueError("skip must be non-negative")
|
|
||||||
if limit < 0:
|
|
||||||
raise ValueError("limit must be non-negative")
|
|
||||||
if limit > 1000:
|
|
||||||
raise ValueError("Maximum limit is 1000")
|
|
||||||
|
|
||||||
try:
|
|
||||||
query = select(self.model).offset(skip).limit(limit)
|
|
||||||
|
|
||||||
# Apply eager loading options if provided
|
|
||||||
if options:
|
|
||||||
for option in options:
|
|
||||||
query = query.options(option)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
|
|
||||||
"""Create a new record with error handling."""
|
|
||||||
try:
|
|
||||||
obj_in_data = jsonable_encoder(obj_in)
|
|
||||||
db_obj = self.model(**obj_in_data)
|
|
||||||
db.add(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_obj)
|
|
||||||
return db_obj
|
|
||||||
except IntegrityError as e:
|
|
||||||
await db.rollback()
|
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
|
||||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
|
||||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
|
||||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
|
||||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
|
||||||
except (OperationalError, DataError) as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
|
||||||
raise ValueError(f"Database operation failed: {str(e)}")
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
db_obj: ModelType,
|
|
||||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
|
||||||
) -> ModelType:
|
|
||||||
"""Update a record with error handling."""
|
|
||||||
try:
|
|
||||||
obj_data = jsonable_encoder(db_obj)
|
|
||||||
if isinstance(obj_in, dict):
|
|
||||||
update_data = obj_in
|
|
||||||
else:
|
|
||||||
update_data = obj_in.model_dump(exclude_unset=True)
|
|
||||||
|
|
||||||
for field in obj_data:
|
|
||||||
if field in update_data:
|
|
||||||
setattr(db_obj, field, update_data[field])
|
|
||||||
|
|
||||||
db.add(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_obj)
|
|
||||||
return db_obj
|
|
||||||
except IntegrityError as e:
|
|
||||||
await db.rollback()
|
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
|
||||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
|
||||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
|
||||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
|
||||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
|
||||||
except (OperationalError, DataError) as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
|
||||||
raise ValueError(f"Database operation failed: {str(e)}")
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
|
||||||
"""Delete a record with error handling and null check."""
|
|
||||||
# Validate UUID format and convert to UUID object if string
|
|
||||||
try:
|
|
||||||
if isinstance(id, uuid.UUID):
|
|
||||||
uuid_obj = id
|
|
||||||
else:
|
|
||||||
uuid_obj = uuid.UUID(str(id))
|
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
|
||||||
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(self.model).where(self.model.id == uuid_obj)
|
|
||||||
)
|
|
||||||
obj = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if obj is None:
|
|
||||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
|
||||||
return None
|
|
||||||
|
|
||||||
await db.delete(obj)
|
|
||||||
await db.commit()
|
|
||||||
return obj
|
|
||||||
except IntegrityError as e:
|
|
||||||
await db.rollback()
|
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
|
||||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
|
||||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_multi_with_total(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
sort_by: Optional[str] = None,
|
|
||||||
sort_order: str = "asc",
|
|
||||||
filters: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Tuple[List[ModelType], int]:
|
|
||||||
"""
|
|
||||||
Get multiple records with total count, filtering, and sorting.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
sort_by: Field name to sort by (must be a valid model attribute)
|
|
||||||
sort_order: Sort order ("asc" or "desc")
|
|
||||||
filters: Dictionary of filters (field_name: value)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (items, total_count)
|
|
||||||
"""
|
|
||||||
# Validate pagination parameters
|
|
||||||
if skip < 0:
|
|
||||||
raise ValueError("skip must be non-negative")
|
|
||||||
if limit < 0:
|
|
||||||
raise ValueError("limit must be non-negative")
|
|
||||||
if limit > 1000:
|
|
||||||
raise ValueError("Maximum limit is 1000")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Build base query
|
|
||||||
query = select(self.model)
|
|
||||||
|
|
||||||
# Exclude soft-deleted records by default
|
|
||||||
if hasattr(self.model, 'deleted_at'):
|
|
||||||
query = query.where(self.model.deleted_at.is_(None))
|
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if filters:
|
|
||||||
for field, value in filters.items():
|
|
||||||
if hasattr(self.model, field) and value is not None:
|
|
||||||
query = query.where(getattr(self.model, field) == value)
|
|
||||||
|
|
||||||
# Get total count (before pagination)
|
|
||||||
count_query = select(func.count()).select_from(query.alias())
|
|
||||||
count_result = await db.execute(count_query)
|
|
||||||
total = count_result.scalar_one()
|
|
||||||
|
|
||||||
# Apply sorting
|
|
||||||
if sort_by and hasattr(self.model, sort_by):
|
|
||||||
sort_column = getattr(self.model, sort_by)
|
|
||||||
if sort_order.lower() == "desc":
|
|
||||||
query = query.order_by(sort_column.desc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(sort_column.asc())
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
query = query.offset(skip).limit(limit)
|
|
||||||
items_result = await db.execute(query)
|
|
||||||
items = list(items_result.scalars().all())
|
|
||||||
|
|
||||||
return items, total
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def count(self, db: AsyncSession) -> int:
|
|
||||||
"""Get total count of records."""
|
|
||||||
try:
|
|
||||||
result = await db.execute(select(func.count(self.model.id)))
|
|
||||||
return result.scalar_one()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
|
||||||
"""Check if a record exists by ID."""
|
|
||||||
obj = await self.get(db, id=id)
|
|
||||||
return obj is not None
|
|
||||||
|
|
||||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
|
||||||
"""
|
|
||||||
Soft delete a record by setting deleted_at timestamp.
|
|
||||||
|
|
||||||
Only works if the model has a 'deleted_at' column.
|
|
||||||
"""
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
# Validate UUID format and convert to UUID object if string
|
|
||||||
try:
|
|
||||||
if isinstance(id, uuid.UUID):
|
|
||||||
uuid_obj = id
|
|
||||||
else:
|
|
||||||
uuid_obj = uuid.UUID(str(id))
|
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
|
||||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(self.model).where(self.model.id == uuid_obj)
|
|
||||||
)
|
|
||||||
obj = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if obj is None:
|
|
||||||
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Check if model supports soft deletes
|
|
||||||
if not hasattr(self.model, 'deleted_at'):
|
|
||||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
|
||||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
|
||||||
|
|
||||||
# Set deleted_at timestamp
|
|
||||||
obj.deleted_at = datetime.now(timezone.utc)
|
|
||||||
db.add(obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(obj)
|
|
||||||
return obj
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
|
||||||
"""
|
|
||||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
|
||||||
|
|
||||||
Only works if the model has a 'deleted_at' column.
|
|
||||||
"""
|
|
||||||
# Validate UUID format
|
|
||||||
try:
|
|
||||||
if isinstance(id, uuid.UUID):
|
|
||||||
uuid_obj = id
|
|
||||||
else:
|
|
||||||
uuid_obj = uuid.UUID(str(id))
|
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
|
||||||
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Find the soft-deleted record
|
|
||||||
if hasattr(self.model, 'deleted_at'):
|
|
||||||
result = await db.execute(
|
|
||||||
select(self.model).where(
|
|
||||||
self.model.id == uuid_obj,
|
|
||||||
self.model.deleted_at.isnot(None)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
obj = result.scalar_one_or_none()
|
|
||||||
else:
|
|
||||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
|
||||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
|
||||||
|
|
||||||
if obj is None:
|
|
||||||
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Clear deleted_at timestamp
|
|
||||||
obj.deleted_at = None
|
|
||||||
db.add(obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(obj)
|
|
||||||
return obj
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
434
backend/app/crud/organization.py
Normal file → Executable file
434
backend/app/crud/organization.py
Normal file → Executable file
@@ -1,11 +1,12 @@
|
|||||||
# app/crud/organization.py
|
# app/crud/organization_async.py
|
||||||
|
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import func, or_, and_
|
from sqlalchemy import func, or_, and_, select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.crud.base import CRUDBase
|
from app.crud.base import CRUDBase
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
@@ -13,20 +14,27 @@ from app.models.user import User
|
|||||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||||
from app.schemas.organizations import (
|
from app.schemas.organizations import (
|
||||||
OrganizationCreate,
|
OrganizationCreate,
|
||||||
OrganizationUpdate
|
OrganizationUpdate,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||||
"""CRUD operations for Organization model."""
|
"""Async CRUD operations for Organization model."""
|
||||||
|
|
||||||
def get_by_slug(self, db: Session, *, slug: str) -> Optional[Organization]:
|
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
|
||||||
"""Get organization by slug."""
|
"""Get organization by slug."""
|
||||||
return db.query(Organization).filter(Organization.slug == slug).first()
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Organization).where(Organization.slug == slug)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def create(self, db: Session, *, obj_in: OrganizationCreate) -> Organization:
|
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
|
||||||
"""Create a new organization with error handling."""
|
"""Create a new organization with error handling."""
|
||||||
try:
|
try:
|
||||||
db_obj = Organization(
|
db_obj = Organization(
|
||||||
@@ -37,11 +45,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
settings=obj_in.settings or {}
|
settings=obj_in.settings or {}
|
||||||
)
|
)
|
||||||
db.add(db_obj)
|
db.add(db_obj)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(db_obj)
|
await db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||||
if "slug" in error_msg.lower():
|
if "slug" in error_msg.lower():
|
||||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||||
@@ -49,13 +57,13 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
|
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_multi_with_filters(
|
async def get_multi_with_filters(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
@@ -70,47 +78,139 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (organizations list, total count)
|
Tuple of (organizations list, total count)
|
||||||
"""
|
"""
|
||||||
query = db.query(Organization)
|
try:
|
||||||
|
query = select(Organization)
|
||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.filter(Organization.is_active == is_active)
|
query = query.where(Organization.is_active == is_active)
|
||||||
|
|
||||||
if search:
|
if search:
|
||||||
search_filter = or_(
|
search_filter = or_(
|
||||||
Organization.name.ilike(f"%{search}%"),
|
Organization.name.ilike(f"%{search}%"),
|
||||||
Organization.slug.ilike(f"%{search}%"),
|
Organization.slug.ilike(f"%{search}%"),
|
||||||
Organization.description.ilike(f"%{search}%")
|
Organization.description.ilike(f"%{search}%")
|
||||||
)
|
)
|
||||||
query = query.filter(search_filter)
|
query = query.where(search_filter)
|
||||||
|
|
||||||
# Get total count before pagination
|
# Get total count before pagination
|
||||||
total = query.count()
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply sorting
|
# Apply sorting
|
||||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||||
if sort_order == "desc":
|
if sort_order == "desc":
|
||||||
query = query.order_by(sort_column.desc())
|
query = query.order_by(sort_column.desc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(sort_column.asc())
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
organizations = query.offset(skip).limit(limit).all()
|
query = query.offset(skip).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
organizations = list(result.scalars().all())
|
||||||
|
|
||||||
return organizations, total
|
return organizations, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting organizations with filters: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def get_member_count(self, db: Session, *, organization_id: UUID) -> int:
|
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||||
"""Get the count of active members in an organization."""
|
"""Get the count of active members in an organization."""
|
||||||
return db.query(func.count(UserOrganization.user_id)).filter(
|
try:
|
||||||
and_(
|
result = await db.execute(
|
||||||
UserOrganization.organization_id == organization_id,
|
select(func.count(UserOrganization.user_id)).where(
|
||||||
UserOrganization.is_active == True
|
and_(
|
||||||
|
UserOrganization.organization_id == organization_id,
|
||||||
|
UserOrganization.is_active == True
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).scalar() or 0
|
return result.scalar_one() or 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def add_user(
|
async def get_multi_with_member_counts(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
is_active: Optional[bool] = None,
|
||||||
|
search: Optional[str] = None
|
||||||
|
) -> tuple[List[Dict[str, Any]], int]:
|
||||||
|
"""
|
||||||
|
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
||||||
|
This eliminates the N+1 query problem.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (list of dicts with org and member_count, total count)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Build base query with LEFT JOIN and GROUP BY
|
||||||
|
query = (
|
||||||
|
select(
|
||||||
|
Organization,
|
||||||
|
func.count(
|
||||||
|
func.distinct(
|
||||||
|
and_(
|
||||||
|
UserOrganization.is_active == True,
|
||||||
|
UserOrganization.user_id
|
||||||
|
).self_group()
|
||||||
|
)
|
||||||
|
).label('member_count')
|
||||||
|
)
|
||||||
|
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||||
|
.group_by(Organization.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if is_active is not None:
|
||||||
|
query = query.where(Organization.is_active == is_active)
|
||||||
|
|
||||||
|
if search:
|
||||||
|
search_filter = or_(
|
||||||
|
Organization.name.ilike(f"%{search}%"),
|
||||||
|
Organization.slug.ilike(f"%{search}%"),
|
||||||
|
Organization.description.ilike(f"%{search}%")
|
||||||
|
)
|
||||||
|
query = query.where(search_filter)
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = select(func.count(Organization.id))
|
||||||
|
if is_active is not None:
|
||||||
|
count_query = count_query.where(Organization.is_active == is_active)
|
||||||
|
if search:
|
||||||
|
count_query = count_query.where(search_filter)
|
||||||
|
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
|
# Apply pagination and ordering
|
||||||
|
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
rows = result.all()
|
||||||
|
|
||||||
|
# Convert to list of dicts
|
||||||
|
orgs_with_counts = [
|
||||||
|
{
|
||||||
|
'organization': org,
|
||||||
|
'member_count': member_count
|
||||||
|
}
|
||||||
|
for org, member_count in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
return orgs_with_counts, total
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def add_user(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
@@ -120,12 +220,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
"""Add a user to an organization with a specific role."""
|
"""Add a user to an organization with a specific role."""
|
||||||
try:
|
try:
|
||||||
# Check if relationship already exists
|
# Check if relationship already exists
|
||||||
existing = db.query(UserOrganization).filter(
|
result = await db.execute(
|
||||||
and_(
|
select(UserOrganization).where(
|
||||||
UserOrganization.user_id == user_id,
|
and_(
|
||||||
UserOrganization.organization_id == organization_id
|
UserOrganization.user_id == user_id,
|
||||||
|
UserOrganization.organization_id == organization_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).first()
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
# Reactivate if inactive, or raise error if already active
|
# Reactivate if inactive, or raise error if already active
|
||||||
@@ -133,8 +236,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
existing.is_active = True
|
existing.is_active = True
|
||||||
existing.role = role
|
existing.role = role
|
||||||
existing.custom_permissions = custom_permissions
|
existing.custom_permissions = custom_permissions
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(existing)
|
await db.refresh(existing)
|
||||||
return existing
|
return existing
|
||||||
else:
|
else:
|
||||||
raise ValueError("User is already a member of this organization")
|
raise ValueError("User is already a member of this organization")
|
||||||
@@ -148,48 +251,51 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
custom_permissions=custom_permissions
|
custom_permissions=custom_permissions
|
||||||
)
|
)
|
||||||
db.add(user_org)
|
db.add(user_org)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(user_org)
|
await db.refresh(user_org)
|
||||||
return user_org
|
return user_org
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Integrity error adding user to organization: {str(e)}")
|
logger.error(f"Integrity error adding user to organization: {str(e)}")
|
||||||
raise ValueError("Failed to add user to organization")
|
raise ValueError("Failed to add user to organization")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
|
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def remove_user(
|
async def remove_user(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
user_id: UUID
|
user_id: UUID
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Remove a user from an organization (soft delete)."""
|
"""Remove a user from an organization (soft delete)."""
|
||||||
try:
|
try:
|
||||||
user_org = db.query(UserOrganization).filter(
|
result = await db.execute(
|
||||||
and_(
|
select(UserOrganization).where(
|
||||||
UserOrganization.user_id == user_id,
|
and_(
|
||||||
UserOrganization.organization_id == organization_id
|
UserOrganization.user_id == user_id,
|
||||||
|
UserOrganization.organization_id == organization_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).first()
|
)
|
||||||
|
user_org = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not user_org:
|
if not user_org:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
user_org.is_active = False
|
user_org.is_active = False
|
||||||
db.commit()
|
await db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
|
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def update_user_role(
|
async def update_user_role(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
@@ -198,12 +304,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
) -> Optional[UserOrganization]:
|
) -> Optional[UserOrganization]:
|
||||||
"""Update a user's role in an organization."""
|
"""Update a user's role in an organization."""
|
||||||
try:
|
try:
|
||||||
user_org = db.query(UserOrganization).filter(
|
result = await db.execute(
|
||||||
and_(
|
select(UserOrganization).where(
|
||||||
UserOrganization.user_id == user_id,
|
and_(
|
||||||
UserOrganization.organization_id == organization_id
|
UserOrganization.user_id == user_id,
|
||||||
|
UserOrganization.organization_id == organization_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).first()
|
)
|
||||||
|
user_org = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not user_org:
|
if not user_org:
|
||||||
return None
|
return None
|
||||||
@@ -211,17 +320,17 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
user_org.role = role
|
user_org.role = role
|
||||||
if custom_permissions is not None:
|
if custom_permissions is not None:
|
||||||
user_org.custom_permissions = custom_permissions
|
user_org.custom_permissions = custom_permissions
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(user_org)
|
await db.refresh(user_org)
|
||||||
return user_org
|
return user_org
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
|
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_organization_members(
|
async def get_organization_members(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
@@ -234,86 +343,175 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (members list with user details, total count)
|
Tuple of (members list with user details, total count)
|
||||||
"""
|
"""
|
||||||
query = db.query(UserOrganization, User).join(
|
try:
|
||||||
User, UserOrganization.user_id == User.id
|
# Build query with join
|
||||||
).filter(UserOrganization.organization_id == organization_id)
|
query = (
|
||||||
|
select(UserOrganization, User)
|
||||||
|
.join(User, UserOrganization.user_id == User.id)
|
||||||
|
.where(UserOrganization.organization_id == organization_id)
|
||||||
|
)
|
||||||
|
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.filter(UserOrganization.is_active == is_active)
|
query = query.where(UserOrganization.is_active == is_active)
|
||||||
|
|
||||||
total = query.count()
|
# Get total count
|
||||||
|
count_query = select(func.count()).select_from(
|
||||||
|
select(UserOrganization)
|
||||||
|
.where(UserOrganization.organization_id == organization_id)
|
||||||
|
.where(UserOrganization.is_active == is_active if is_active is not None else True)
|
||||||
|
.alias()
|
||||||
|
)
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
results = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit).all()
|
# Apply ordering and pagination
|
||||||
|
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
results = result.all()
|
||||||
|
|
||||||
members = []
|
members = []
|
||||||
for user_org, user in results:
|
for user_org, user in results:
|
||||||
members.append({
|
members.append({
|
||||||
"user_id": user.id,
|
"user_id": user.id,
|
||||||
"email": user.email,
|
"email": user.email,
|
||||||
"first_name": user.first_name,
|
"first_name": user.first_name,
|
||||||
"last_name": user.last_name,
|
"last_name": user.last_name,
|
||||||
"role": user_org.role,
|
"role": user_org.role,
|
||||||
"is_active": user_org.is_active,
|
"is_active": user_org.is_active,
|
||||||
"joined_at": user_org.created_at
|
"joined_at": user_org.created_at
|
||||||
})
|
})
|
||||||
|
|
||||||
return members, total
|
return members, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting organization members: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def get_user_organizations(
|
async def get_user_organizations(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
) -> List[Organization]:
|
) -> List[Organization]:
|
||||||
"""Get all organizations a user belongs to."""
|
"""Get all organizations a user belongs to."""
|
||||||
query = db.query(Organization).join(
|
try:
|
||||||
UserOrganization, Organization.id == UserOrganization.organization_id
|
query = (
|
||||||
).filter(UserOrganization.user_id == user_id)
|
select(Organization)
|
||||||
|
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||||
|
.where(UserOrganization.user_id == user_id)
|
||||||
|
)
|
||||||
|
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.filter(UserOrganization.is_active == is_active)
|
query = query.where(UserOrganization.is_active == is_active)
|
||||||
|
|
||||||
return query.all()
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting user organizations: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def get_user_role_in_org(
|
async def get_user_organizations_with_details(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: UUID,
|
||||||
|
is_active: bool = True
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get user's organizations with role and member count in SINGLE QUERY.
|
||||||
|
Eliminates N+1 problem by using subquery for member counts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts with organization, role, and member_count
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Subquery to get member counts for each organization
|
||||||
|
member_count_subq = (
|
||||||
|
select(
|
||||||
|
UserOrganization.organization_id,
|
||||||
|
func.count(UserOrganization.user_id).label('member_count')
|
||||||
|
)
|
||||||
|
.where(UserOrganization.is_active == True)
|
||||||
|
.group_by(UserOrganization.organization_id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Main query with JOIN to get org, role, and member count
|
||||||
|
query = (
|
||||||
|
select(
|
||||||
|
Organization,
|
||||||
|
UserOrganization.role,
|
||||||
|
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
|
||||||
|
)
|
||||||
|
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||||
|
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
|
||||||
|
.where(UserOrganization.user_id == user_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_active is not None:
|
||||||
|
query = query.where(UserOrganization.is_active == is_active)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
rows = result.all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
'organization': org,
|
||||||
|
'role': role,
|
||||||
|
'member_count': member_count
|
||||||
|
}
|
||||||
|
for org, role, member_count in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_user_role_in_org(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
organization_id: UUID
|
organization_id: UUID
|
||||||
) -> Optional[OrganizationRole]:
|
) -> Optional[OrganizationRole]:
|
||||||
"""Get a user's role in a specific organization."""
|
"""Get a user's role in a specific organization."""
|
||||||
user_org = db.query(UserOrganization).filter(
|
try:
|
||||||
and_(
|
result = await db.execute(
|
||||||
UserOrganization.user_id == user_id,
|
select(UserOrganization).where(
|
||||||
UserOrganization.organization_id == organization_id,
|
and_(
|
||||||
UserOrganization.is_active == True
|
UserOrganization.user_id == user_id,
|
||||||
|
UserOrganization.organization_id == organization_id,
|
||||||
|
UserOrganization.is_active == True
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).first()
|
user_org = result.scalar_one_or_none()
|
||||||
|
|
||||||
return user_org.role if user_org else None
|
return user_org.role if user_org else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting user role in org: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def is_user_org_owner(
|
async def is_user_org_owner(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
organization_id: UUID
|
organization_id: UUID
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if a user is an owner of an organization."""
|
"""Check if a user is an owner of an organization."""
|
||||||
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||||
return role == OrganizationRole.OWNER
|
return role == OrganizationRole.OWNER
|
||||||
|
|
||||||
def is_user_org_admin(
|
async def is_user_org_admin(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
organization_id: UUID
|
organization_id: UUID
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if a user is an owner or admin of an organization."""
|
"""Check if a user is an owner or admin of an organization."""
|
||||||
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,519 +0,0 @@
|
|||||||
# app/crud/organization_async.py
|
|
||||||
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
|
||||||
import logging
|
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy import func, or_, and_, select
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.crud.base_async import CRUDBaseAsync
|
|
||||||
from app.models.organization import Organization
|
|
||||||
from app.models.user import User
|
|
||||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
|
||||||
from app.schemas.organizations import (
|
|
||||||
OrganizationCreate,
|
|
||||||
OrganizationUpdate,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class CRUDOrganizationAsync(CRUDBaseAsync[Organization, OrganizationCreate, OrganizationUpdate]):
|
|
||||||
"""Async CRUD operations for Organization model."""
|
|
||||||
|
|
||||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
|
|
||||||
"""Get organization by slug."""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(Organization).where(Organization.slug == slug)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
|
|
||||||
"""Create a new organization with error handling."""
|
|
||||||
try:
|
|
||||||
db_obj = Organization(
|
|
||||||
name=obj_in.name,
|
|
||||||
slug=obj_in.slug,
|
|
||||||
description=obj_in.description,
|
|
||||||
is_active=obj_in.is_active,
|
|
||||||
settings=obj_in.settings or {}
|
|
||||||
)
|
|
||||||
db.add(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_obj)
|
|
||||||
return db_obj
|
|
||||||
except IntegrityError as e:
|
|
||||||
await db.rollback()
|
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
|
||||||
if "slug" in error_msg.lower():
|
|
||||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
|
||||||
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
|
|
||||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_multi_with_filters(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
is_active: Optional[bool] = None,
|
|
||||||
search: Optional[str] = None,
|
|
||||||
sort_by: str = "created_at",
|
|
||||||
sort_order: str = "desc"
|
|
||||||
) -> tuple[List[Organization], int]:
|
|
||||||
"""
|
|
||||||
Get multiple organizations with filtering, searching, and sorting.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (organizations list, total count)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
query = select(Organization)
|
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if is_active is not None:
|
|
||||||
query = query.where(Organization.is_active == is_active)
|
|
||||||
|
|
||||||
if search:
|
|
||||||
search_filter = or_(
|
|
||||||
Organization.name.ilike(f"%{search}%"),
|
|
||||||
Organization.slug.ilike(f"%{search}%"),
|
|
||||||
Organization.description.ilike(f"%{search}%")
|
|
||||||
)
|
|
||||||
query = query.where(search_filter)
|
|
||||||
|
|
||||||
# Get total count before pagination
|
|
||||||
count_query = select(func.count()).select_from(query.alias())
|
|
||||||
count_result = await db.execute(count_query)
|
|
||||||
total = count_result.scalar_one()
|
|
||||||
|
|
||||||
# Apply sorting
|
|
||||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
|
||||||
if sort_order == "desc":
|
|
||||||
query = query.order_by(sort_column.desc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(sort_column.asc())
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
query = query.offset(skip).limit(limit)
|
|
||||||
result = await db.execute(query)
|
|
||||||
organizations = list(result.scalars().all())
|
|
||||||
|
|
||||||
return organizations, total
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting organizations with filters: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
|
||||||
"""Get the count of active members in an organization."""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(func.count(UserOrganization.user_id)).where(
|
|
||||||
and_(
|
|
||||||
UserOrganization.organization_id == organization_id,
|
|
||||||
UserOrganization.is_active == True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one() or 0
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_multi_with_member_counts(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
is_active: Optional[bool] = None,
|
|
||||||
search: Optional[str] = None
|
|
||||||
) -> tuple[List[Dict[str, Any]], int]:
|
|
||||||
"""
|
|
||||||
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
|
||||||
This eliminates the N+1 query problem.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (list of dicts with org and member_count, total count)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Build base query with LEFT JOIN and GROUP BY
|
|
||||||
query = (
|
|
||||||
select(
|
|
||||||
Organization,
|
|
||||||
func.count(
|
|
||||||
func.distinct(
|
|
||||||
and_(
|
|
||||||
UserOrganization.is_active == True,
|
|
||||||
UserOrganization.user_id
|
|
||||||
).self_group()
|
|
||||||
)
|
|
||||||
).label('member_count')
|
|
||||||
)
|
|
||||||
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
|
|
||||||
.group_by(Organization.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if is_active is not None:
|
|
||||||
query = query.where(Organization.is_active == is_active)
|
|
||||||
|
|
||||||
if search:
|
|
||||||
search_filter = or_(
|
|
||||||
Organization.name.ilike(f"%{search}%"),
|
|
||||||
Organization.slug.ilike(f"%{search}%"),
|
|
||||||
Organization.description.ilike(f"%{search}%")
|
|
||||||
)
|
|
||||||
query = query.where(search_filter)
|
|
||||||
|
|
||||||
# Get total count
|
|
||||||
count_query = select(func.count(Organization.id))
|
|
||||||
if is_active is not None:
|
|
||||||
count_query = count_query.where(Organization.is_active == is_active)
|
|
||||||
if search:
|
|
||||||
count_query = count_query.where(search_filter)
|
|
||||||
|
|
||||||
count_result = await db.execute(count_query)
|
|
||||||
total = count_result.scalar_one()
|
|
||||||
|
|
||||||
# Apply pagination and ordering
|
|
||||||
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
rows = result.all()
|
|
||||||
|
|
||||||
# Convert to list of dicts
|
|
||||||
orgs_with_counts = [
|
|
||||||
{
|
|
||||||
'organization': org,
|
|
||||||
'member_count': member_count
|
|
||||||
}
|
|
||||||
for org, member_count in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
return orgs_with_counts, total
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def add_user(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
organization_id: UUID,
|
|
||||||
user_id: UUID,
|
|
||||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
|
||||||
custom_permissions: Optional[str] = None
|
|
||||||
) -> UserOrganization:
|
|
||||||
"""Add a user to an organization with a specific role."""
|
|
||||||
try:
|
|
||||||
# Check if relationship already exists
|
|
||||||
result = await db.execute(
|
|
||||||
select(UserOrganization).where(
|
|
||||||
and_(
|
|
||||||
UserOrganization.user_id == user_id,
|
|
||||||
UserOrganization.organization_id == organization_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
existing = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
# Reactivate if inactive, or raise error if already active
|
|
||||||
if not existing.is_active:
|
|
||||||
existing.is_active = True
|
|
||||||
existing.role = role
|
|
||||||
existing.custom_permissions = custom_permissions
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(existing)
|
|
||||||
return existing
|
|
||||||
else:
|
|
||||||
raise ValueError("User is already a member of this organization")
|
|
||||||
|
|
||||||
# Create new relationship
|
|
||||||
user_org = UserOrganization(
|
|
||||||
user_id=user_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
role=role,
|
|
||||||
is_active=True,
|
|
||||||
custom_permissions=custom_permissions
|
|
||||||
)
|
|
||||||
db.add(user_org)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user_org)
|
|
||||||
return user_org
|
|
||||||
except IntegrityError as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Integrity error adding user to organization: {str(e)}")
|
|
||||||
raise ValueError("Failed to add user to organization")
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def remove_user(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
organization_id: UUID,
|
|
||||||
user_id: UUID
|
|
||||||
) -> bool:
|
|
||||||
"""Remove a user from an organization (soft delete)."""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(UserOrganization).where(
|
|
||||||
and_(
|
|
||||||
UserOrganization.user_id == user_id,
|
|
||||||
UserOrganization.organization_id == organization_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
user_org = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not user_org:
|
|
||||||
return False
|
|
||||||
|
|
||||||
user_org.is_active = False
|
|
||||||
await db.commit()
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update_user_role(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
organization_id: UUID,
|
|
||||||
user_id: UUID,
|
|
||||||
role: OrganizationRole,
|
|
||||||
custom_permissions: Optional[str] = None
|
|
||||||
) -> Optional[UserOrganization]:
|
|
||||||
"""Update a user's role in an organization."""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(UserOrganization).where(
|
|
||||||
and_(
|
|
||||||
UserOrganization.user_id == user_id,
|
|
||||||
UserOrganization.organization_id == organization_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
user_org = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not user_org:
|
|
||||||
return None
|
|
||||||
|
|
||||||
user_org.role = role
|
|
||||||
if custom_permissions is not None:
|
|
||||||
user_org.custom_permissions = custom_permissions
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user_org)
|
|
||||||
return user_org
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_organization_members(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
organization_id: UUID,
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
is_active: bool = True
|
|
||||||
) -> tuple[List[Dict[str, Any]], int]:
|
|
||||||
"""
|
|
||||||
Get members of an organization with user details.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (members list with user details, total count)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Build query with join
|
|
||||||
query = (
|
|
||||||
select(UserOrganization, User)
|
|
||||||
.join(User, UserOrganization.user_id == User.id)
|
|
||||||
.where(UserOrganization.organization_id == organization_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_active is not None:
|
|
||||||
query = query.where(UserOrganization.is_active == is_active)
|
|
||||||
|
|
||||||
# Get total count
|
|
||||||
count_query = select(func.count()).select_from(
|
|
||||||
select(UserOrganization)
|
|
||||||
.where(UserOrganization.organization_id == organization_id)
|
|
||||||
.where(UserOrganization.is_active == is_active if is_active is not None else True)
|
|
||||||
.alias()
|
|
||||||
)
|
|
||||||
count_result = await db.execute(count_query)
|
|
||||||
total = count_result.scalar_one()
|
|
||||||
|
|
||||||
# Apply ordering and pagination
|
|
||||||
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
|
|
||||||
result = await db.execute(query)
|
|
||||||
results = result.all()
|
|
||||||
|
|
||||||
members = []
|
|
||||||
for user_org, user in results:
|
|
||||||
members.append({
|
|
||||||
"user_id": user.id,
|
|
||||||
"email": user.email,
|
|
||||||
"first_name": user.first_name,
|
|
||||||
"last_name": user.last_name,
|
|
||||||
"role": user_org.role,
|
|
||||||
"is_active": user_org.is_active,
|
|
||||||
"joined_at": user_org.created_at
|
|
||||||
})
|
|
||||||
|
|
||||||
return members, total
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting organization members: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_user_organizations(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: UUID,
|
|
||||||
is_active: bool = True
|
|
||||||
) -> List[Organization]:
|
|
||||||
"""Get all organizations a user belongs to."""
|
|
||||||
try:
|
|
||||||
query = (
|
|
||||||
select(Organization)
|
|
||||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
|
||||||
.where(UserOrganization.user_id == user_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_active is not None:
|
|
||||||
query = query.where(UserOrganization.is_active == is_active)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting user organizations: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_user_organizations_with_details(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: UUID,
|
|
||||||
is_active: bool = True
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Get user's organizations with role and member count in SINGLE QUERY.
|
|
||||||
Eliminates N+1 problem by using subquery for member counts.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts with organization, role, and member_count
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Subquery to get member counts for each organization
|
|
||||||
member_count_subq = (
|
|
||||||
select(
|
|
||||||
UserOrganization.organization_id,
|
|
||||||
func.count(UserOrganization.user_id).label('member_count')
|
|
||||||
)
|
|
||||||
.where(UserOrganization.is_active == True)
|
|
||||||
.group_by(UserOrganization.organization_id)
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Main query with JOIN to get org, role, and member count
|
|
||||||
query = (
|
|
||||||
select(
|
|
||||||
Organization,
|
|
||||||
UserOrganization.role,
|
|
||||||
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
|
|
||||||
)
|
|
||||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
|
||||||
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
|
|
||||||
.where(UserOrganization.user_id == user_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_active is not None:
|
|
||||||
query = query.where(UserOrganization.is_active == is_active)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
rows = result.all()
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
'organization': org,
|
|
||||||
'role': role,
|
|
||||||
'member_count': member_count
|
|
||||||
}
|
|
||||||
for org, role, member_count in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_user_role_in_org(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: UUID,
|
|
||||||
organization_id: UUID
|
|
||||||
) -> Optional[OrganizationRole]:
|
|
||||||
"""Get a user's role in a specific organization."""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(UserOrganization).where(
|
|
||||||
and_(
|
|
||||||
UserOrganization.user_id == user_id,
|
|
||||||
UserOrganization.organization_id == organization_id,
|
|
||||||
UserOrganization.is_active == True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
user_org = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
return user_org.role if user_org else None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting user role in org: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def is_user_org_owner(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: UUID,
|
|
||||||
organization_id: UUID
|
|
||||||
) -> bool:
|
|
||||||
"""Check if a user is an owner of an organization."""
|
|
||||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
|
||||||
return role == OrganizationRole.OWNER
|
|
||||||
|
|
||||||
async def is_user_org_admin(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: UUID,
|
|
||||||
organization_id: UUID
|
|
||||||
) -> bool:
|
|
||||||
"""Check if a user is an owner or admin of an organization."""
|
|
||||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
|
||||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance for use across the application
|
|
||||||
organization_async = CRUDOrganizationAsync(Organization)
|
|
||||||
220
backend/app/crud/session.py
Normal file → Executable file
220
backend/app/crud/session.py
Normal file → Executable file
@@ -1,13 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
CRUD operations for user sessions.
|
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_, select, update, delete, func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from app.crud.base import CRUDBase
|
from app.crud.base import CRUDBase
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
@@ -17,9 +18,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||||
"""CRUD operations for user sessions."""
|
"""Async CRUD operations for user sessions."""
|
||||||
|
|
||||||
def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||||
"""
|
"""
|
||||||
Get session by refresh token JTI.
|
Get session by refresh token JTI.
|
||||||
|
|
||||||
@@ -31,14 +32,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
UserSession if found, None otherwise
|
UserSession if found, None otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return db.query(UserSession).filter(
|
result = await db.execute(
|
||||||
UserSession.refresh_token_jti == jti
|
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||||
).first()
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||||
"""
|
"""
|
||||||
Get active session by refresh token JTI.
|
Get active session by refresh token JTI.
|
||||||
|
|
||||||
@@ -50,30 +52,35 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
Active UserSession if found, None otherwise
|
Active UserSession if found, None otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return db.query(UserSession).filter(
|
result = await db.execute(
|
||||||
and_(
|
select(UserSession).where(
|
||||||
UserSession.refresh_token_jti == jti,
|
and_(
|
||||||
UserSession.is_active == True
|
UserSession.refresh_token_jti == jti,
|
||||||
|
UserSession.is_active == True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).first()
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_user_sessions(
|
async def get_user_sessions(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
active_only: bool = True
|
active_only: bool = True,
|
||||||
|
with_user: bool = False
|
||||||
) -> List[UserSession]:
|
) -> List[UserSession]:
|
||||||
"""
|
"""
|
||||||
Get all sessions for a user.
|
Get all sessions for a user with optional eager loading.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
user_id: User ID
|
user_id: User ID
|
||||||
active_only: If True, return only active sessions
|
active_only: If True, return only active sessions
|
||||||
|
with_user: If True, eager load user relationship to prevent N+1
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of UserSession objects
|
List of UserSession objects
|
||||||
@@ -82,19 +89,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
# Convert user_id string to UUID if needed
|
# Convert user_id string to UUID if needed
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
query = db.query(UserSession).filter(UserSession.user_id == user_uuid)
|
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||||
|
|
||||||
|
# Add eager loading if requested to prevent N+1 queries
|
||||||
|
if with_user:
|
||||||
|
query = query.options(joinedload(UserSession.user))
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
query = query.filter(UserSession.is_active == True)
|
query = query.where(UserSession.is_active == True)
|
||||||
|
|
||||||
return query.order_by(UserSession.last_used_at.desc()).all()
|
query = query.order_by(UserSession.last_used_at.desc())
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def create_session(
|
async def create_session(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
obj_in: SessionCreate
|
obj_in: SessionCreate
|
||||||
) -> UserSession:
|
) -> UserSession:
|
||||||
@@ -126,8 +139,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
location_country=obj_in.location_country,
|
location_country=obj_in.location_country,
|
||||||
)
|
)
|
||||||
db.add(db_obj)
|
db.add(db_obj)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(db_obj)
|
await db.refresh(db_obj)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
||||||
@@ -136,11 +149,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
|
|
||||||
return db_obj
|
return db_obj
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
||||||
raise ValueError(f"Failed to create session: {str(e)}")
|
raise ValueError(f"Failed to create session: {str(e)}")
|
||||||
|
|
||||||
def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]:
|
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
|
||||||
"""
|
"""
|
||||||
Deactivate a session (logout from device).
|
Deactivate a session (logout from device).
|
||||||
|
|
||||||
@@ -152,15 +165,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
Deactivated UserSession if found, None otherwise
|
Deactivated UserSession if found, None otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = self.get(db, id=session_id)
|
session = await self.get(db, id=session_id)
|
||||||
if not session:
|
if not session:
|
||||||
logger.warning(f"Session {session_id} not found for deactivation")
|
logger.warning(f"Session {session_id} not found for deactivation")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
session.is_active = False
|
session.is_active = False
|
||||||
db.add(session)
|
db.add(session)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(session)
|
await db.refresh(session)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Session {session_id} deactivated for user {session.user_id} "
|
f"Session {session_id} deactivated for user {session.user_id} "
|
||||||
@@ -169,13 +182,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
|
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def deactivate_all_user_sessions(
|
async def deactivate_all_user_sessions(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
user_id: str
|
user_id: str
|
||||||
) -> int:
|
) -> int:
|
||||||
@@ -193,26 +206,33 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
# Convert user_id string to UUID if needed
|
# Convert user_id string to UUID if needed
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
count = db.query(UserSession).filter(
|
stmt = (
|
||||||
and_(
|
update(UserSession)
|
||||||
UserSession.user_id == user_uuid,
|
.where(
|
||||||
UserSession.is_active == True
|
and_(
|
||||||
|
UserSession.user_id == user_uuid,
|
||||||
|
UserSession.is_active == True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).update({"is_active": False})
|
.values(is_active=False)
|
||||||
|
)
|
||||||
|
|
||||||
db.commit()
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
count = result.rowcount
|
||||||
|
|
||||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
||||||
|
|
||||||
return count
|
return count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def update_last_used(
|
async def update_last_used(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
session: UserSession
|
session: UserSession
|
||||||
) -> UserSession:
|
) -> UserSession:
|
||||||
@@ -229,17 +249,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
try:
|
try:
|
||||||
session.last_used_at = datetime.now(timezone.utc)
|
session.last_used_at = datetime.now(timezone.utc)
|
||||||
db.add(session)
|
db.add(session)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(session)
|
await db.refresh(session)
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def update_refresh_token(
|
async def update_refresh_token(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
session: UserSession,
|
session: UserSession,
|
||||||
new_jti: str,
|
new_jti: str,
|
||||||
@@ -264,22 +284,24 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
session.expires_at = new_expires_at
|
session.expires_at = new_expires_at
|
||||||
session.last_used_at = datetime.now(timezone.utc)
|
session.last_used_at = datetime.now(timezone.utc)
|
||||||
db.add(session)
|
db.add(session)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(session)
|
await db.refresh(session)
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int:
|
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||||
"""
|
"""
|
||||||
Clean up expired sessions.
|
Clean up expired sessions using optimized bulk DELETE.
|
||||||
|
|
||||||
Deletes sessions that are:
|
Deletes sessions that are:
|
||||||
- Expired AND inactive
|
- Expired AND inactive
|
||||||
- Older than keep_days
|
- Older than keep_days
|
||||||
|
|
||||||
|
Uses single DELETE query instead of N individual deletes for efficiency.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
keep_days: Keep inactive sessions for this many days (for audit)
|
keep_days: Keep inactive sessions for this many days (for audit)
|
||||||
@@ -289,31 +311,87 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# Delete sessions that are:
|
# Use bulk DELETE with WHERE clause - single query
|
||||||
# 1. Expired (expires_at < now) AND inactive
|
stmt = delete(UserSession).where(
|
||||||
# AND
|
|
||||||
# 2. Older than keep_days
|
|
||||||
count = db.query(UserSession).filter(
|
|
||||||
and_(
|
and_(
|
||||||
UserSession.is_active == False,
|
UserSession.is_active == False,
|
||||||
UserSession.expires_at < datetime.now(timezone.utc),
|
UserSession.expires_at < now,
|
||||||
UserSession.created_at < cutoff_date
|
UserSession.created_at < cutoff_date
|
||||||
)
|
)
|
||||||
).delete()
|
)
|
||||||
|
|
||||||
db.commit()
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
count = result.rowcount
|
||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
logger.info(f"Cleaned up {count} expired sessions")
|
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
||||||
|
|
||||||
return count
|
return count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_user_session_count(self, db: Session, *, user_id: str) -> int:
|
async def cleanup_expired_for_user(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: str
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Clean up expired and inactive sessions for a specific user.
|
||||||
|
|
||||||
|
Uses single bulk DELETE query for efficiency instead of N individual deletes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
user_id: User ID to cleanup sessions for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of sessions deleted
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate UUID
|
||||||
|
try:
|
||||||
|
uuid_obj = uuid.UUID(user_id)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
logger.error(f"Invalid UUID format: {user_id}")
|
||||||
|
raise ValueError(f"Invalid user ID format: {user_id}")
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Use bulk DELETE with WHERE clause - single query
|
||||||
|
stmt = delete(UserSession).where(
|
||||||
|
and_(
|
||||||
|
UserSession.user_id == uuid_obj,
|
||||||
|
UserSession.is_active == False,
|
||||||
|
UserSession.expires_at < now
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
count = result.rowcount
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
|
||||||
|
)
|
||||||
|
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
Get count of active sessions for a user.
|
Get count of active sessions for a user.
|
||||||
|
|
||||||
@@ -325,12 +403,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
Number of active sessions
|
Number of active sessions
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return db.query(UserSession).filter(
|
# Convert user_id string to UUID if needed
|
||||||
and_(
|
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
UserSession.user_id == user_id,
|
|
||||||
UserSession.is_active == True
|
result = await db.execute(
|
||||||
|
select(func.count(UserSession.id)).where(
|
||||||
|
and_(
|
||||||
|
UserSession.user_id == user_uuid,
|
||||||
|
UserSession.is_active == True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).count()
|
)
|
||||||
|
return result.scalar_one()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,424 +0,0 @@
|
|||||||
"""
|
|
||||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timezone, timedelta
|
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy import and_, select, update, delete, func
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
|
|
||||||
from app.crud.base_async import CRUDBaseAsync
|
|
||||||
from app.models.user_session import UserSession
|
|
||||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate]):
|
|
||||||
"""Async CRUD operations for user sessions."""
|
|
||||||
|
|
||||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
|
||||||
"""
|
|
||||||
Get session by refresh token JTI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
jti: Refresh token JWT ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
|
||||||
"""
|
|
||||||
Get active session by refresh token JTI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
jti: Refresh token JWT ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Active UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(UserSession).where(
|
|
||||||
and_(
|
|
||||||
UserSession.refresh_token_jti == jti,
|
|
||||||
UserSession.is_active == True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_user_sessions(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
active_only: bool = True,
|
|
||||||
with_user: bool = False
|
|
||||||
) -> List[UserSession]:
|
|
||||||
"""
|
|
||||||
Get all sessions for a user with optional eager loading.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
active_only: If True, return only active sessions
|
|
||||||
with_user: If True, eager load user relationship to prevent N+1
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of UserSession objects
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
|
||||||
|
|
||||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
|
||||||
|
|
||||||
# Add eager loading if requested to prevent N+1 queries
|
|
||||||
if with_user:
|
|
||||||
query = query.options(joinedload(UserSession.user))
|
|
||||||
|
|
||||||
if active_only:
|
|
||||||
query = query.where(UserSession.is_active == True)
|
|
||||||
|
|
||||||
query = query.order_by(UserSession.last_used_at.desc())
|
|
||||||
result = await db.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def create_session(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
obj_in: SessionCreate
|
|
||||||
) -> UserSession:
|
|
||||||
"""
|
|
||||||
Create a new user session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
obj_in: SessionCreate schema with session data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created UserSession
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If session creation fails
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
db_obj = UserSession(
|
|
||||||
user_id=obj_in.user_id,
|
|
||||||
refresh_token_jti=obj_in.refresh_token_jti,
|
|
||||||
device_name=obj_in.device_name,
|
|
||||||
device_id=obj_in.device_id,
|
|
||||||
ip_address=obj_in.ip_address,
|
|
||||||
user_agent=obj_in.user_agent,
|
|
||||||
last_used_at=obj_in.last_used_at,
|
|
||||||
expires_at=obj_in.expires_at,
|
|
||||||
is_active=True,
|
|
||||||
location_city=obj_in.location_city,
|
|
||||||
location_country=obj_in.location_country,
|
|
||||||
)
|
|
||||||
db.add(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_obj)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
|
||||||
f"(IP: {obj_in.ip_address})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return db_obj
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
|
||||||
raise ValueError(f"Failed to create session: {str(e)}")
|
|
||||||
|
|
||||||
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
|
|
||||||
"""
|
|
||||||
Deactivate a session (logout from device).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session_id: Session UUID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deactivated UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
session = await self.get(db, id=session_id)
|
|
||||||
if not session:
|
|
||||||
logger.warning(f"Session {session_id} not found for deactivation")
|
|
||||||
return None
|
|
||||||
|
|
||||||
session.is_active = False
|
|
||||||
db.add(session)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(session)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Session {session_id} deactivated for user {session.user_id} "
|
|
||||||
f"({session.device_name})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return session
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def deactivate_all_user_sessions(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: str
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Deactivate all active sessions for a user (logout from all devices).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deactivated
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
|
||||||
|
|
||||||
stmt = (
|
|
||||||
update(UserSession)
|
|
||||||
.where(
|
|
||||||
and_(
|
|
||||||
UserSession.user_id == user_uuid,
|
|
||||||
UserSession.is_active == True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.values(is_active=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await db.execute(stmt)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
count = result.rowcount
|
|
||||||
|
|
||||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
|
||||||
|
|
||||||
return count
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update_last_used(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
session: UserSession
|
|
||||||
) -> UserSession:
|
|
||||||
"""
|
|
||||||
Update the last_used_at timestamp for a session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session: UserSession object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated UserSession
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
session.last_used_at = datetime.now(timezone.utc)
|
|
||||||
db.add(session)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(session)
|
|
||||||
return session
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update_refresh_token(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
session: UserSession,
|
|
||||||
new_jti: str,
|
|
||||||
new_expires_at: datetime
|
|
||||||
) -> UserSession:
|
|
||||||
"""
|
|
||||||
Update session with new refresh token JTI and expiration.
|
|
||||||
|
|
||||||
Called during token refresh.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session: UserSession object
|
|
||||||
new_jti: New refresh token JTI
|
|
||||||
new_expires_at: New expiration datetime
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated UserSession
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
session.refresh_token_jti = new_jti
|
|
||||||
session.expires_at = new_expires_at
|
|
||||||
session.last_used_at = datetime.now(timezone.utc)
|
|
||||||
db.add(session)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(session)
|
|
||||||
return session
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
|
||||||
"""
|
|
||||||
Clean up expired sessions using optimized bulk DELETE.
|
|
||||||
|
|
||||||
Deletes sessions that are:
|
|
||||||
- Expired AND inactive
|
|
||||||
- Older than keep_days
|
|
||||||
|
|
||||||
Uses single DELETE query instead of N individual deletes for efficiency.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
keep_days: Keep inactive sessions for this many days (for audit)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deleted
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
# Use bulk DELETE with WHERE clause - single query
|
|
||||||
stmt = delete(UserSession).where(
|
|
||||||
and_(
|
|
||||||
UserSession.is_active == False,
|
|
||||||
UserSession.expires_at < now,
|
|
||||||
UserSession.created_at < cutoff_date
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await db.execute(stmt)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
count = result.rowcount
|
|
||||||
|
|
||||||
if count > 0:
|
|
||||||
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
|
||||||
|
|
||||||
return count
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def cleanup_expired_for_user(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: str
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Clean up expired and inactive sessions for a specific user.
|
|
||||||
|
|
||||||
Uses single bulk DELETE query for efficiency instead of N individual deletes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID to cleanup sessions for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deleted
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Validate UUID
|
|
||||||
try:
|
|
||||||
uuid_obj = uuid.UUID(user_id)
|
|
||||||
except (ValueError, AttributeError):
|
|
||||||
logger.error(f"Invalid UUID format: {user_id}")
|
|
||||||
raise ValueError(f"Invalid user ID format: {user_id}")
|
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
# Use bulk DELETE with WHERE clause - single query
|
|
||||||
stmt = delete(UserSession).where(
|
|
||||||
and_(
|
|
||||||
UserSession.user_id == uuid_obj,
|
|
||||||
UserSession.is_active == False,
|
|
||||||
UserSession.expires_at < now
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await db.execute(stmt)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
count = result.rowcount
|
|
||||||
|
|
||||||
if count > 0:
|
|
||||||
logger.info(
|
|
||||||
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
|
|
||||||
)
|
|
||||||
|
|
||||||
return count
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(
|
|
||||||
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
|
||||||
"""
|
|
||||||
Get count of active sessions for a user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of active sessions
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(func.count(UserSession.id)).where(
|
|
||||||
and_(
|
|
||||||
UserSession.user_id == user_uuid,
|
|
||||||
UserSession.is_active == True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# Create singleton instance
|
|
||||||
session_async = CRUDSessionAsync(UserSession)
|
|
||||||
183
backend/app/crud/user.py
Normal file → Executable file
183
backend/app/crud/user.py
Normal file → Executable file
@@ -1,12 +1,15 @@
|
|||||||
# app/crud/user.py
|
# app/crud/user_async.py
|
||||||
|
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
from typing import Optional, Union, Dict, Any, List, Tuple
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import or_, asc, desc
|
from sqlalchemy import or_, select, update
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import get_password_hash
|
from app.core.auth import get_password_hash_async
|
||||||
from app.crud.base import CRUDBase
|
from app.crud.base import CRUDBase
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import UserCreate, UserUpdate
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
@@ -15,15 +18,28 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||||
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
|
"""Async CRUD operations for User model."""
|
||||||
return db.query(User).filter(User.email == email).first()
|
|
||||||
|
|
||||||
def create(self, db: Session, *, obj_in: UserCreate) -> User:
|
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
|
||||||
"""Create a new user with password hashing and error handling."""
|
"""Get user by email address."""
|
||||||
try:
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(User).where(User.email == email)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting user by email {email}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||||
|
"""Create a new user with async password hashing and error handling."""
|
||||||
|
try:
|
||||||
|
# Hash password asynchronously to avoid blocking event loop
|
||||||
|
password_hash = await get_password_hash_async(obj_in.password)
|
||||||
|
|
||||||
db_obj = User(
|
db_obj = User(
|
||||||
email=obj_in.email,
|
email=obj_in.email,
|
||||||
password_hash=get_password_hash(obj_in.password),
|
password_hash=password_hash,
|
||||||
first_name=obj_in.first_name,
|
first_name=obj_in.first_name,
|
||||||
last_name=obj_in.last_name,
|
last_name=obj_in.last_name,
|
||||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||||
@@ -31,11 +47,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
preferences={}
|
preferences={}
|
||||||
)
|
)
|
||||||
db.add(db_obj)
|
db.add(db_obj)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(db_obj)
|
await db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||||
if "email" in error_msg.lower():
|
if "email" in error_msg.lower():
|
||||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||||
@@ -43,32 +59,34 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
logger.error(f"Integrity error creating user: {error_msg}")
|
logger.error(f"Integrity error creating user: {error_msg}")
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def update(
|
async def update(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
db_obj: User,
|
db_obj: User,
|
||||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||||
) -> User:
|
) -> User:
|
||||||
|
"""Update user with async password hashing if password is updated."""
|
||||||
if isinstance(obj_in, dict):
|
if isinstance(obj_in, dict):
|
||||||
update_data = obj_in
|
update_data = obj_in
|
||||||
else:
|
else:
|
||||||
update_data = obj_in.model_dump(exclude_unset=True)
|
update_data = obj_in.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
# Handle password separately if it exists in update data
|
# Handle password separately if it exists in update data
|
||||||
|
# Hash password asynchronously to avoid blocking event loop
|
||||||
if "password" in update_data:
|
if "password" in update_data:
|
||||||
update_data["password_hash"] = get_password_hash(update_data["password"])
|
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
|
||||||
del update_data["password"]
|
del update_data["password"]
|
||||||
|
|
||||||
return super().update(db, db_obj=db_obj, obj_in=update_data)
|
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||||
|
|
||||||
def get_multi_with_total(
|
async def get_multi_with_total(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
@@ -102,16 +120,16 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Build base query
|
# Build base query
|
||||||
query = db.query(User)
|
query = select(User)
|
||||||
|
|
||||||
# Exclude soft-deleted users
|
# Exclude soft-deleted users
|
||||||
query = query.filter(User.deleted_at.is_(None))
|
query = query.where(User.deleted_at.is_(None))
|
||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
if filters:
|
if filters:
|
||||||
for field, value in filters.items():
|
for field, value in filters.items():
|
||||||
if hasattr(User, field) and value is not None:
|
if hasattr(User, field) and value is not None:
|
||||||
query = query.filter(getattr(User, field) == value)
|
query = query.where(getattr(User, field) == value)
|
||||||
|
|
||||||
# Apply search
|
# Apply search
|
||||||
if search:
|
if search:
|
||||||
@@ -120,21 +138,26 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
User.first_name.ilike(f"%{search}%"),
|
User.first_name.ilike(f"%{search}%"),
|
||||||
User.last_name.ilike(f"%{search}%")
|
User.last_name.ilike(f"%{search}%")
|
||||||
)
|
)
|
||||||
query = query.filter(search_filter)
|
query = query.where(search_filter)
|
||||||
|
|
||||||
# Get total count
|
# Get total count
|
||||||
total = query.count()
|
from sqlalchemy import func
|
||||||
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply sorting
|
# Apply sorting
|
||||||
if sort_by and hasattr(User, sort_by):
|
if sort_by and hasattr(User, sort_by):
|
||||||
sort_column = getattr(User, sort_by)
|
sort_column = getattr(User, sort_by)
|
||||||
if sort_order.lower() == "desc":
|
if sort_order.lower() == "desc":
|
||||||
query = query.order_by(desc(sort_column))
|
query = query.order_by(sort_column.desc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(asc(sort_column))
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
users = query.offset(skip).limit(limit).all()
|
query = query.offset(skip).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
users = list(result.scalars().all())
|
||||||
|
|
||||||
return users, total
|
return users, total
|
||||||
|
|
||||||
@@ -142,12 +165,108 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
logger.error(f"Error retrieving paginated users: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def bulk_update_status(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_ids: List[UUID],
|
||||||
|
is_active: bool
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Bulk update is_active status for multiple users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
user_ids: List of user IDs to update
|
||||||
|
is_active: New active status
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of users updated
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not user_ids:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Use UPDATE with WHERE IN for efficiency
|
||||||
|
stmt = (
|
||||||
|
update(User)
|
||||||
|
.where(User.id.in_(user_ids))
|
||||||
|
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
||||||
|
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
updated_count = result.rowcount
|
||||||
|
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
||||||
|
return updated_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def bulk_soft_delete(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_ids: List[UUID],
|
||||||
|
exclude_user_id: Optional[UUID] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Bulk soft delete multiple users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
user_ids: List of user IDs to delete
|
||||||
|
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of users deleted
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not user_ids:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Remove excluded user from list
|
||||||
|
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||||
|
|
||||||
|
if not filtered_ids:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Use UPDATE with WHERE IN for efficiency
|
||||||
|
stmt = (
|
||||||
|
update(User)
|
||||||
|
.where(User.id.in_(filtered_ids))
|
||||||
|
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
|
||||||
|
.values(
|
||||||
|
deleted_at=datetime.now(timezone.utc),
|
||||||
|
is_active=False,
|
||||||
|
updated_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
deleted_count = result.rowcount
|
||||||
|
logger.info(f"Bulk soft deleted {deleted_count} users")
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
def is_active(self, user: User) -> bool:
|
def is_active(self, user: User) -> bool:
|
||||||
|
"""Check if user is active."""
|
||||||
return user.is_active
|
return user.is_active
|
||||||
|
|
||||||
def is_superuser(self, user: User) -> bool:
|
def is_superuser(self, user: User) -> bool:
|
||||||
|
"""Check if user is a superuser."""
|
||||||
return user.is_superuser
|
return user.is_superuser
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance for use across the application
|
# Create a singleton instance for use across the application
|
||||||
user = CRUDUser(User)
|
user = CRUDUser(User)
|
||||||
|
|||||||
@@ -1,272 +0,0 @@
|
|||||||
# app/crud/user_async.py
|
|
||||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy import or_, select, update
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.core.auth import get_password_hash_async
|
|
||||||
from app.crud.base_async import CRUDBaseAsync
|
|
||||||
from app.models.user import User
|
|
||||||
from app.schemas.users import UserCreate, UserUpdate
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class CRUDUserAsync(CRUDBaseAsync[User, UserCreate, UserUpdate]):
|
|
||||||
"""Async CRUD operations for User model."""
|
|
||||||
|
|
||||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
|
|
||||||
"""Get user by email address."""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(User).where(User.email == email)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting user by email {email}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
|
||||||
"""Create a new user with async password hashing and error handling."""
|
|
||||||
try:
|
|
||||||
# Hash password asynchronously to avoid blocking event loop
|
|
||||||
password_hash = await get_password_hash_async(obj_in.password)
|
|
||||||
|
|
||||||
db_obj = User(
|
|
||||||
email=obj_in.email,
|
|
||||||
password_hash=password_hash,
|
|
||||||
first_name=obj_in.first_name,
|
|
||||||
last_name=obj_in.last_name,
|
|
||||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
|
||||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
|
||||||
preferences={}
|
|
||||||
)
|
|
||||||
db.add(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_obj)
|
|
||||||
return db_obj
|
|
||||||
except IntegrityError as e:
|
|
||||||
await db.rollback()
|
|
||||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
|
||||||
if "email" in error_msg.lower():
|
|
||||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
|
||||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
|
||||||
logger.error(f"Integrity error creating user: {error_msg}")
|
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
db_obj: User,
|
|
||||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
|
||||||
) -> User:
|
|
||||||
"""Update user with async password hashing if password is updated."""
|
|
||||||
if isinstance(obj_in, dict):
|
|
||||||
update_data = obj_in
|
|
||||||
else:
|
|
||||||
update_data = obj_in.model_dump(exclude_unset=True)
|
|
||||||
|
|
||||||
# Handle password separately if it exists in update data
|
|
||||||
# Hash password asynchronously to avoid blocking event loop
|
|
||||||
if "password" in update_data:
|
|
||||||
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
|
|
||||||
del update_data["password"]
|
|
||||||
|
|
||||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
|
||||||
|
|
||||||
async def get_multi_with_total(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
skip: int = 0,
|
|
||||||
limit: int = 100,
|
|
||||||
sort_by: Optional[str] = None,
|
|
||||||
sort_order: str = "asc",
|
|
||||||
filters: Optional[Dict[str, Any]] = None,
|
|
||||||
search: Optional[str] = None
|
|
||||||
) -> Tuple[List[User], int]:
|
|
||||||
"""
|
|
||||||
Get multiple users with total count, filtering, sorting, and search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
sort_by: Field name to sort by
|
|
||||||
sort_order: Sort order ("asc" or "desc")
|
|
||||||
filters: Dictionary of filters (field_name: value)
|
|
||||||
search: Search term to match against email, first_name, last_name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (users list, total count)
|
|
||||||
"""
|
|
||||||
# Validate pagination
|
|
||||||
if skip < 0:
|
|
||||||
raise ValueError("skip must be non-negative")
|
|
||||||
if limit < 0:
|
|
||||||
raise ValueError("limit must be non-negative")
|
|
||||||
if limit > 1000:
|
|
||||||
raise ValueError("Maximum limit is 1000")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Build base query
|
|
||||||
query = select(User)
|
|
||||||
|
|
||||||
# Exclude soft-deleted users
|
|
||||||
query = query.where(User.deleted_at.is_(None))
|
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if filters:
|
|
||||||
for field, value in filters.items():
|
|
||||||
if hasattr(User, field) and value is not None:
|
|
||||||
query = query.where(getattr(User, field) == value)
|
|
||||||
|
|
||||||
# Apply search
|
|
||||||
if search:
|
|
||||||
search_filter = or_(
|
|
||||||
User.email.ilike(f"%{search}%"),
|
|
||||||
User.first_name.ilike(f"%{search}%"),
|
|
||||||
User.last_name.ilike(f"%{search}%")
|
|
||||||
)
|
|
||||||
query = query.where(search_filter)
|
|
||||||
|
|
||||||
# Get total count
|
|
||||||
from sqlalchemy import func
|
|
||||||
count_query = select(func.count()).select_from(query.alias())
|
|
||||||
count_result = await db.execute(count_query)
|
|
||||||
total = count_result.scalar_one()
|
|
||||||
|
|
||||||
# Apply sorting
|
|
||||||
if sort_by and hasattr(User, sort_by):
|
|
||||||
sort_column = getattr(User, sort_by)
|
|
||||||
if sort_order.lower() == "desc":
|
|
||||||
query = query.order_by(sort_column.desc())
|
|
||||||
else:
|
|
||||||
query = query.order_by(sort_column.asc())
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
query = query.offset(skip).limit(limit)
|
|
||||||
result = await db.execute(query)
|
|
||||||
users = list(result.scalars().all())
|
|
||||||
|
|
||||||
return users, total
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def bulk_update_status(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_ids: List[UUID],
|
|
||||||
is_active: bool
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Bulk update is_active status for multiple users.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_ids: List of user IDs to update
|
|
||||||
is_active: New active status
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of users updated
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not user_ids:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Use UPDATE with WHERE IN for efficiency
|
|
||||||
stmt = (
|
|
||||||
update(User)
|
|
||||||
.where(User.id.in_(user_ids))
|
|
||||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
|
||||||
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await db.execute(stmt)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
updated_count = result.rowcount
|
|
||||||
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
|
||||||
return updated_count
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def bulk_soft_delete(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_ids: List[UUID],
|
|
||||||
exclude_user_id: Optional[UUID] = None
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Bulk soft delete multiple users.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_ids: List of user IDs to delete
|
|
||||||
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of users deleted
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not user_ids:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Remove excluded user from list
|
|
||||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
|
||||||
|
|
||||||
if not filtered_ids:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Use UPDATE with WHERE IN for efficiency
|
|
||||||
stmt = (
|
|
||||||
update(User)
|
|
||||||
.where(User.id.in_(filtered_ids))
|
|
||||||
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
|
|
||||||
.values(
|
|
||||||
deleted_at=datetime.now(timezone.utc),
|
|
||||||
is_active=False,
|
|
||||||
updated_at=datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await db.execute(stmt)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
deleted_count = result.rowcount
|
|
||||||
logger.info(f"Bulk soft deleted {deleted_count} users")
|
|
||||||
return deleted_count
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def is_active(self, user: User) -> bool:
|
|
||||||
"""Check if user is active."""
|
|
||||||
return user.is_active
|
|
||||||
|
|
||||||
def is_superuser(self, user: User) -> bool:
|
|
||||||
"""Check if user is a superuser."""
|
|
||||||
return user.is_superuser
|
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance for use across the application
|
|
||||||
user_async = CRUDUserAsync(User)
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
# app/init_db.py
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.database import engine
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
from app.schemas.users import UserCreate
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def init_db(db: Session) -> Optional[UserCreate]:
|
|
||||||
"""
|
|
||||||
Initialize database with first superuser if settings are configured and user doesn't exist.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The created or existing superuser, or None if creation fails
|
|
||||||
"""
|
|
||||||
# Use default values if not set in environment variables
|
|
||||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
|
||||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "Admin123!Change"
|
|
||||||
|
|
||||||
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
|
||||||
logger.warning(
|
|
||||||
"First superuser credentials not configured in settings. "
|
|
||||||
f"Using defaults: {superuser_email}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check if superuser already exists
|
|
||||||
existing_user = user_crud.get_by_email(db, email=superuser_email)
|
|
||||||
|
|
||||||
if existing_user:
|
|
||||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
|
||||||
return existing_user
|
|
||||||
|
|
||||||
# Create superuser if doesn't exist
|
|
||||||
user_in = UserCreate(
|
|
||||||
email=superuser_email,
|
|
||||||
password=superuser_password,
|
|
||||||
first_name="Admin",
|
|
||||||
last_name="User",
|
|
||||||
is_superuser=True
|
|
||||||
)
|
|
||||||
|
|
||||||
user = user_crud.create(db, obj_in=user_in)
|
|
||||||
logger.info(f"Created first superuser: {user.email}")
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error initializing database: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Configure logging to show info logs
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
||||||
)
|
|
||||||
|
|
||||||
with Session(engine) as session:
|
|
||||||
try:
|
|
||||||
user = init_db(session)
|
|
||||||
if user:
|
|
||||||
print(f"✓ Database initialized successfully")
|
|
||||||
print(f"✓ Superuser: {user.email}")
|
|
||||||
else:
|
|
||||||
print("✗ Failed to initialize database")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Error initializing database: {e}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
@@ -13,7 +13,7 @@ from slowapi.util import get_remote_address
|
|||||||
|
|
||||||
from app.api.main import api_router
|
from app.api.main import api_router
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database_async import check_database_health
|
from app.core.database import check_database_health
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
APIException,
|
APIException,
|
||||||
api_exception_handler,
|
api_exception_handler,
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ This service runs periodically to remove old session records from the database.
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from app.core.database_async import AsyncSessionLocal
|
from app.core.database import SessionLocal
|
||||||
from app.crud.session_async import session_async as session_crud
|
from app.crud.session import session as session_crud
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
|||||||
"""
|
"""
|
||||||
logger.info("Starting session cleanup job...")
|
logger.info("Starting session cleanup job...")
|
||||||
|
|
||||||
async with AsyncSessionLocal() as db:
|
async with SessionLocal() as db:
|
||||||
try:
|
try:
|
||||||
# Use CRUD method to cleanup
|
# Use CRUD method to cleanup
|
||||||
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
|
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
|
||||||
@@ -50,7 +50,7 @@ async def get_session_statistics() -> dict:
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary with session stats
|
Dictionary with session stats
|
||||||
"""
|
"""
|
||||||
async with AsyncSessionLocal() as db:
|
async with SessionLocal() as db:
|
||||||
try:
|
try:
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
from sqlalchemy import select, func
|
from sqlalchemy import select, func
|
||||||
|
|||||||
Reference in New Issue
Block a user