diff --git a/backend/app/__init__.py b/backend/app/__init__.py old mode 100644 new mode 100755 diff --git a/backend/app/api/dependencies/auth.py b/backend/app/api/dependencies/auth.py old mode 100644 new mode 100755 index 417e1cc..736f209 --- a/backend/app/api/dependencies/auth.py +++ b/backend/app/api/dependencies/auth.py @@ -3,18 +3,19 @@ from typing import Optional from fastapi import Depends, HTTPException, status, Header from fastapi.security import OAuth2PasswordBearer from fastapi.security.utils import get_authorization_scheme_param -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError -from app.core.database import get_db +from app.core.database_async import get_async_db from app.models.user import User # OAuth2 configuration oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") -def get_current_user( - db: Session = Depends(get_db), +async def get_current_user( + db: AsyncSession = Depends(get_async_db), token: str = Depends(oauth2_scheme) ) -> User: """ @@ -35,7 +36,11 @@ def get_current_user( token_data = get_token_data(token) # Get user from database - user = db.query(User).filter(User.id == token_data.user_id).first() + result = await db.execute( + select(User).where(User.id == token_data.user_id) + ) + user = result.scalar_one_or_none() + if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -133,8 +138,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str] return token -def get_optional_current_user( - db: Session = Depends(get_db), +async def get_optional_current_user( + db: AsyncSession = Depends(get_async_db), token: Optional[str] = Depends(get_optional_token) ) -> Optional[User]: """ @@ -153,7 +158,10 @@ def get_optional_current_user( try: token_data = get_token_data(token) - user = db.query(User).filter(User.id == token_data.user_id).first() + result = await db.execute( + select(User).where(User.id == token_data.user_id) + ) + user = result.scalar_one_or_none() if not user or not user.is_active: return None return user diff --git a/backend/app/api/dependencies/permissions.py b/backend/app/api/dependencies/permissions.py old mode 100644 new mode 100755 index f7cac20..73749ce --- a/backend/app/api/dependencies/permissions.py +++ b/backend/app/api/dependencies/permissions.py @@ -10,13 +10,13 @@ These dependencies are optional and flexible: from typing import Optional from uuid import UUID from fastapi import Depends, HTTPException, status -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession -from app.core.database import get_db +from app.core.database_async import get_async_db from app.models.user import User from app.models.user_organization import OrganizationRole from app.api.dependencies.auth import get_current_user -from app.crud.organization import organization as organization_crud +from app.crud.organization_async import organization_async as organization_crud def require_superuser( @@ -73,11 +73,11 @@ class OrganizationPermission: """ self.allowed_roles = allowed_roles - def __call__( + async def __call__( self, organization_id: UUID, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> User: """ Check if user has required role in the organization. @@ -98,7 +98,7 @@ class OrganizationPermission: return current_user # Get user's role in organization - user_role = organization_crud.get_user_role_in_org( + user_role = await organization_crud.get_user_role_in_org( db, user_id=current_user.id, organization_id=organization_id @@ -129,10 +129,10 @@ require_org_member = OrganizationPermission([ ]) -def get_current_org_role( +async def get_current_org_role( organization_id: UUID, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Optional[OrganizationRole]: """ Get the current user's role in an organization. @@ -142,7 +142,7 @@ def get_current_org_role( Example: @router.get("/organizations/{org_id}/items") - def list_items( + async def list_items( org_id: UUID, role: OrganizationRole = Depends(get_current_org_role) ): @@ -153,17 +153,17 @@ def get_current_org_role( if current_user.is_superuser: return OrganizationRole.OWNER - return organization_crud.get_user_role_in_org( + return await organization_crud.get_user_role_in_org( db, user_id=current_user.id, organization_id=organization_id ) -def require_org_membership( +async def require_org_membership( organization_id: UUID, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> User: """ Ensure user is a member of the organization (any role). @@ -173,7 +173,7 @@ def require_org_membership( if current_user.is_superuser: return current_user - user_role = organization_crud.get_user_role_in_org( + user_role = await organization_crud.get_user_role_in_org( db, user_id=current_user.id, organization_id=organization_id diff --git a/backend/app/api/routes/admin.py b/backend/app/api/routes/admin.py old mode 100644 new mode 100755 index 8813511..5a2e8a3 --- a/backend/app/api/routes/admin.py +++ b/backend/app/api/routes/admin.py @@ -11,13 +11,13 @@ from uuid import UUID from enum import Enum from fastapi import APIRouter, Depends, Query, Body, status -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from pydantic import BaseModel, Field from app.api.dependencies.permissions import require_superuser -from app.core.database import get_db -from app.crud.user import user as user_crud -from app.crud.organization import organization as organization_crud +from app.core.database_async import get_async_db +from app.crud.user_async import user_async as user_crud +from app.crud.organization_async import organization_async as organization_crud from app.models.user import User from app.models.user_organization import OrganizationRole from app.schemas.users import UserResponse, UserCreate, UserUpdate @@ -73,14 +73,14 @@ class BulkActionResult(BaseModel): description="Get paginated list of all users with filtering and search (admin only)", operation_id="admin_list_users" ) -def admin_list_users( +async def admin_list_users( pagination: PaginationParams = Depends(), sort: SortParams = Depends(), is_active: Optional[bool] = Query(None, description="Filter by active status"), is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), search: Optional[str] = Query(None, description="Search by email, name"), admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ List all users with comprehensive filtering and search. @@ -96,7 +96,7 @@ def admin_list_users( filters["is_superuser"] = is_superuser # Get users with search - users, total = user_crud.get_multi_with_total( + users, total = await user_crud.get_multi_with_total( db, skip=pagination.offset, limit=pagination.limit, @@ -128,10 +128,10 @@ def admin_list_users( description="Create a new user (admin only)", operation_id="admin_create_user" ) -def admin_create_user( +async def admin_create_user( user_in: UserCreate, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Create a new user with admin privileges. @@ -139,7 +139,7 @@ def admin_create_user( Allows setting is_superuser and other fields. """ try: - user = user_crud.create(db, obj_in=user_in) + user = await user_crud.create(db, obj_in=user_in) logger.info(f"Admin {admin.email} created user {user.email}") return user except ValueError as e: @@ -160,13 +160,13 @@ def admin_create_user( description="Get detailed user information (admin only)", operation_id="admin_get_user" ) -def admin_get_user( +async def admin_get_user( user_id: UUID, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Get detailed information about a specific user.""" - user = user_crud.get(db, id=user_id) + user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( detail=f"User {user_id} not found", @@ -182,22 +182,22 @@ def admin_get_user( description="Update user information (admin only)", operation_id="admin_update_user" ) -def admin_update_user( +async def admin_update_user( user_id: UUID, user_in: UserUpdate, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Update user information with admin privileges.""" try: - user = user_crud.get(db, id=user_id) + user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( detail=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) - updated_user = user_crud.update(db, db_obj=user, obj_in=user_in) + updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in) logger.info(f"Admin {admin.email} updated user {updated_user.email}") return updated_user @@ -215,14 +215,14 @@ def admin_update_user( description="Soft delete a user (admin only)", operation_id="admin_delete_user" ) -def admin_delete_user( +async def admin_delete_user( user_id: UUID, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Soft delete a user (sets deleted_at timestamp).""" try: - user = user_crud.get(db, id=user_id) + user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( detail=f"User {user_id} not found", @@ -236,7 +236,7 @@ def admin_delete_user( error_code=ErrorCode.OPERATION_FORBIDDEN ) - user_crud.soft_delete(db, id=user_id) + await user_crud.soft_delete(db, id=user_id) logger.info(f"Admin {admin.email} deleted user {user.email}") return MessageResponse( @@ -258,21 +258,21 @@ def admin_delete_user( description="Activate a user account (admin only)", operation_id="admin_activate_user" ) -def admin_activate_user( +async def admin_activate_user( user_id: UUID, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Activate a user account.""" try: - user = user_crud.get(db, id=user_id) + user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( detail=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) - user_crud.update(db, db_obj=user, obj_in={"is_active": True}) + await user_crud.update(db, db_obj=user, obj_in={"is_active": True}) logger.info(f"Admin {admin.email} activated user {user.email}") return MessageResponse( @@ -294,14 +294,14 @@ def admin_activate_user( description="Deactivate a user account (admin only)", operation_id="admin_deactivate_user" ) -def admin_deactivate_user( +async def admin_deactivate_user( user_id: UUID, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Deactivate a user account.""" try: - user = user_crud.get(db, id=user_id) + user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( detail=f"User {user_id} not found", @@ -315,7 +315,7 @@ def admin_deactivate_user( error_code=ErrorCode.OPERATION_FORBIDDEN ) - user_crud.update(db, db_obj=user, obj_in={"is_active": False}) + await user_crud.update(db, db_obj=user, obj_in={"is_active": False}) logger.info(f"Admin {admin.email} deactivated user {user.email}") return MessageResponse( @@ -337,10 +337,10 @@ def admin_deactivate_user( description="Perform bulk actions on multiple users (admin only)", operation_id="admin_bulk_user_action" ) -def admin_bulk_user_action( +async def admin_bulk_user_action( bulk_action: BulkUserAction, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Perform bulk actions on multiple users. @@ -354,7 +354,7 @@ def admin_bulk_user_action( try: for user_id in bulk_action.user_ids: try: - user = user_crud.get(db, id=user_id) + user = await user_crud.get(db, id=user_id) if not user: failed_count += 1 failed_ids.append(user_id) @@ -367,11 +367,11 @@ def admin_bulk_user_action( continue if bulk_action.action == BulkAction.ACTIVATE: - user_crud.update(db, db_obj=user, obj_in={"is_active": True}) + await user_crud.update(db, db_obj=user, obj_in={"is_active": True}) elif bulk_action.action == BulkAction.DEACTIVATE: - user_crud.update(db, db_obj=user, obj_in={"is_active": False}) + await user_crud.update(db, db_obj=user, obj_in={"is_active": False}) elif bulk_action.action == BulkAction.DELETE: - user_crud.soft_delete(db, id=user_id) + await user_crud.soft_delete(db, id=user_id) affected_count += 1 @@ -407,16 +407,16 @@ def admin_bulk_user_action( description="Get paginated list of all organizations (admin only)", operation_id="admin_list_organizations" ) -def admin_list_organizations( +async def admin_list_organizations( pagination: PaginationParams = Depends(), is_active: Optional[bool] = Query(None, description="Filter by active status"), search: Optional[str] = Query(None, description="Search by name, slug, description"), admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """List all organizations with filtering and search.""" try: - orgs, total = organization_crud.get_multi_with_filters( + orgs, total = await organization_crud.get_multi_with_filters( db, skip=pagination.offset, limit=pagination.limit, @@ -438,7 +438,7 @@ def admin_list_organizations( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": organization_crud.get_member_count(db, organization_id=org.id) + "member_count": await organization_crud.get_member_count(db, organization_id=org.id) } orgs_with_count.append(OrganizationResponse(**org_dict)) @@ -464,14 +464,14 @@ def admin_list_organizations( description="Create a new organization (admin only)", operation_id="admin_create_organization" ) -def admin_create_organization( +async def admin_create_organization( org_in: OrganizationCreate, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Create a new organization.""" try: - org = organization_crud.create(db, obj_in=org_in) + org = await organization_crud.create(db, obj_in=org_in) logger.info(f"Admin {admin.email} created organization {org.name}") # Add member count @@ -506,13 +506,13 @@ def admin_create_organization( description="Get detailed organization information (admin only)", operation_id="admin_get_organization" ) -def admin_get_organization( +async def admin_get_organization( org_id: UUID, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Get detailed information about a specific organization.""" - org = organization_crud.get(db, id=org_id) + org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( detail=f"Organization {org_id} not found", @@ -528,7 +528,7 @@ def admin_get_organization( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": organization_crud.get_member_count(db, organization_id=org.id) + "member_count": await organization_crud.get_member_count(db, organization_id=org.id) } return OrganizationResponse(**org_dict) @@ -540,22 +540,22 @@ def admin_get_organization( description="Update organization information (admin only)", operation_id="admin_update_organization" ) -def admin_update_organization( +async def admin_update_organization( org_id: UUID, org_in: OrganizationUpdate, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Update organization information.""" try: - org = organization_crud.get(db, id=org_id) + org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( detail=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) - updated_org = organization_crud.update(db, db_obj=org, obj_in=org_in) + updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in) logger.info(f"Admin {admin.email} updated organization {updated_org.name}") org_dict = { @@ -567,7 +567,7 @@ def admin_update_organization( "settings": updated_org.settings, "created_at": updated_org.created_at, "updated_at": updated_org.updated_at, - "member_count": organization_crud.get_member_count(db, organization_id=updated_org.id) + "member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id) } return OrganizationResponse(**org_dict) @@ -585,21 +585,21 @@ def admin_update_organization( description="Delete an organization (admin only)", operation_id="admin_delete_organization" ) -def admin_delete_organization( +async def admin_delete_organization( org_id: UUID, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Delete an organization and all its relationships.""" try: - org = organization_crud.get(db, id=org_id) + org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( detail=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) - organization_crud.remove(db, id=org_id) + await organization_crud.remove(db, id=org_id) logger.info(f"Admin {admin.email} deleted organization {org.name}") return MessageResponse( @@ -621,23 +621,23 @@ def admin_delete_organization( description="Get all members of an organization (admin only)", operation_id="admin_list_organization_members" ) -def admin_list_organization_members( +async def admin_list_organization_members( org_id: UUID, pagination: PaginationParams = Depends(), is_active: Optional[bool] = Query(True, description="Filter by active status"), admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """List all members of an organization.""" try: - org = organization_crud.get(db, id=org_id) + org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( detail=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) - members, total = organization_crud.get_organization_members( + members, total = await organization_crud.get_organization_members( db, organization_id=org_id, skip=pagination.offset, @@ -677,29 +677,29 @@ class AddMemberRequest(BaseModel): description="Add a user to an organization (admin only)", operation_id="admin_add_organization_member" ) -def admin_add_organization_member( +async def admin_add_organization_member( org_id: UUID, request: AddMemberRequest, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Add a user to an organization.""" try: - org = organization_crud.get(db, id=org_id) + org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( detail=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) - user = user_crud.get(db, id=request.user_id) + user = await user_crud.get(db, id=request.user_id) if not user: raise NotFoundError( detail=f"User {request.user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) - organization_crud.add_user( + await organization_crud.add_user( db, organization_id=org_id, user_id=request.user_id, @@ -733,29 +733,29 @@ def admin_add_organization_member( description="Remove a user from an organization (admin only)", operation_id="admin_remove_organization_member" ) -def admin_remove_organization_member( +async def admin_remove_organization_member( org_id: UUID, user_id: UUID, admin: User = Depends(require_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """Remove a user from an organization.""" try: - org = organization_crud.get(db, id=org_id) + org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( detail=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) - user = user_crud.get(db, id=user_id) + user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( detail=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) - success = organization_crud.remove_user( + success = await organization_crud.remove_user( db, organization_id=org_id, user_id=user_id diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py old mode 100644 new mode 100755 index 162f868..fcfe1d0 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -8,11 +8,11 @@ from fastapi import APIRouter, Depends, HTTPException, status, Body, Request from fastapi.security import OAuth2PasswordRequestForm from slowapi import Limiter from slowapi.util import get_remote_address -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token -from app.core.database import get_db +from app.core.database_async import get_async_db from app.models.user import User from app.schemas.users import ( UserCreate, @@ -29,8 +29,8 @@ from app.services.auth_service import AuthService, AuthenticationError from app.services.email_service import email_service from app.utils.security import create_password_reset_token, verify_password_reset_token from app.utils.device import extract_device_info -from app.crud.user import user as user_crud -from app.crud.session import session as session_crud +from app.crud.user_async import user_async as user_crud +from app.crud.session_async import session_async as session_crud from app.core.auth import get_password_hash router = APIRouter() @@ -49,7 +49,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1 async def register_user( request: Request, user_data: UserCreate, - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Register a new user. @@ -58,7 +58,7 @@ async def register_user( The created user information. """ try: - user = AuthService.create_user(db, user_data) + user = await AuthService.create_user(db, user_data) return user except AuthenticationError as e: logger.warning(f"Registration failed: {str(e)}") @@ -79,7 +79,7 @@ async def register_user( async def login( request: Request, login_data: LoginRequest, - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Login with username and password. @@ -91,7 +91,7 @@ async def login( """ try: # Attempt to authenticate the user - user = AuthService.authenticate_user(db, login_data.email, login_data.password) + user = await AuthService.authenticate_user(db, login_data.email, login_data.password) # Explicitly check for None result and raise correct exception if user is None: @@ -126,7 +126,7 @@ async def login( location_country=device_info.location_country, ) - session_crud.create_session(db, obj_in=session_data) + await session_crud.create_session(db, obj_in=session_data) logger.info( f"User login successful: {user.email} from {device_info.device_name} " @@ -163,7 +163,7 @@ async def login( async def login_oauth( request: Request, form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ OAuth2-compatible login endpoint, used by the OpenAPI UI. @@ -174,7 +174,7 @@ async def login_oauth( Access and refresh tokens. """ try: - user = AuthService.authenticate_user(db, form_data.username, form_data.password) + user = await AuthService.authenticate_user(db, form_data.username, form_data.password) if user is None: raise HTTPException( @@ -207,7 +207,7 @@ async def login_oauth( location_country=device_info.location_country, ) - session_crud.create_session(db, obj_in=session_data) + await session_crud.create_session(db, obj_in=session_data) logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}") except Exception as session_err: @@ -241,7 +241,7 @@ async def login_oauth( async def refresh_token( request: Request, refresh_data: RefreshTokenRequest, - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Refresh access token using a refresh token. @@ -256,7 +256,7 @@ async def refresh_token( refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh") # Check if session exists and is active - session = session_crud.get_active_by_jti(db, jti=refresh_payload.jti) + session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti) if not session: logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}") @@ -267,14 +267,14 @@ async def refresh_token( ) # Generate new tokens - tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token) + tokens = await AuthService.refresh_tokens(db, refresh_data.refresh_token) # Decode new refresh token to get new JTI new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh") # Update session with new refresh token JTI and expiration try: - session_crud.update_refresh_token( + await session_crud.update_refresh_token( db, session=session, new_jti=new_refresh_payload.jti, @@ -344,7 +344,7 @@ async def get_current_user_info( async def request_password_reset( request: Request, reset_request: PasswordResetRequest, - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Request a password reset. @@ -354,7 +354,7 @@ async def request_password_reset( """ try: # Look up user by email - user = user_crud.get_by_email(db, email=reset_request.email) + user = await user_crud.get_by_email(db, email=reset_request.email) # Only send email if user exists and is active if user and user.is_active: @@ -399,10 +399,10 @@ async def request_password_reset( operation_id="confirm_password_reset" ) @limiter.limit("5/minute") -def confirm_password_reset( +async def confirm_password_reset( request: Request, reset_confirm: PasswordResetConfirm, - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Confirm password reset with token. @@ -420,7 +420,7 @@ def confirm_password_reset( ) # Look up user - user = user_crud.get_by_email(db, email=email) + user = await user_crud.get_by_email(db, email=email) if not user: raise HTTPException( @@ -437,7 +437,7 @@ def confirm_password_reset( # Update password user.password_hash = get_password_hash(reset_confirm.new_password) db.add(user) - db.commit() + await db.commit() logger.info(f"Password reset successful for {user.email}") @@ -450,7 +450,7 @@ def confirm_password_reset( raise except Exception as e: logger.error(f"Error confirming password reset: {str(e)}", exc_info=True) - db.rollback() + await db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An error occurred while resetting your password" @@ -474,11 +474,11 @@ def confirm_password_reset( operation_id="logout" ) @limiter.limit("10/minute") -def logout( +async def logout( request: Request, logout_request: LogoutRequest, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Logout from current device by deactivating the session. @@ -505,7 +505,7 @@ def logout( ) # Find the session by JTI - session = session_crud.get_by_jti(db, jti=refresh_payload.jti) + session = await session_crud.get_by_jti(db, jti=refresh_payload.jti) if session: # Verify session belongs to current user (security check) @@ -520,7 +520,7 @@ def logout( ) # Deactivate the session - session_crud.deactivate(db, session_id=str(session.id)) + await session_crud.deactivate(db, session_id=str(session.id)) logger.info( f"User {current_user.id} logged out from {session.device_name} " @@ -563,10 +563,10 @@ def logout( operation_id="logout_all" ) @limiter.limit("5/minute") -def logout_all( +async def logout_all( request: Request, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Logout from all devices by deactivating all user sessions. @@ -580,7 +580,7 @@ def logout_all( """ try: # Deactivate all sessions for this user - count = session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id)) + count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id)) logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)") @@ -591,7 +591,7 @@ def logout_all( except Exception as e: logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True) - db.rollback() + await db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An error occurred while logging out" diff --git a/backend/app/api/routes/organizations.py b/backend/app/api/routes/organizations.py old mode 100644 new mode 100755 index 6f8748d..6a756e3 --- a/backend/app/api/routes/organizations.py +++ b/backend/app/api/routes/organizations.py @@ -9,12 +9,12 @@ from typing import Any, List, Optional from uuid import UUID from fastapi import APIRouter, Depends, Query, status -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.api.dependencies.permissions import require_org_admin, require_org_membership, get_current_org_role -from app.core.database import get_db -from app.crud.organization import organization as organization_crud +from app.core.database_async import get_async_db +from app.crud.organization_async import organization_async as organization_crud from app.models.user import User from app.models.user_organization import OrganizationRole from app.schemas.organizations import ( @@ -42,10 +42,10 @@ router = APIRouter() description="Get all organizations the current user belongs to", operation_id="get_my_organizations" ) -def get_my_organizations( +async def get_my_organizations( is_active: bool = Query(True, description="Filter by active membership"), current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Get all organizations the current user belongs to. @@ -53,7 +53,7 @@ def get_my_organizations( Returns organizations with member count for each. """ try: - orgs = organization_crud.get_user_organizations( + orgs = await organization_crud.get_user_organizations( db, user_id=current_user.id, is_active=is_active @@ -77,7 +77,7 @@ def get_my_organizations( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": organization_crud.get_member_count(db, organization_id=org.id) + "member_count": await organization_crud.get_member_count(db, organization_id=org.id) } orgs_with_data.append(OrganizationResponse(**org_dict)) @@ -95,10 +95,10 @@ def get_my_organizations( description="Get details of an organization the user belongs to", operation_id="get_organization" ) -def get_organization( +async def get_organization( organization_id: UUID, current_user: User = Depends(require_org_membership), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Get details of a specific organization. @@ -106,7 +106,7 @@ def get_organization( User must be a member of the organization. """ try: - org = organization_crud.get(db, id=organization_id) + org = await organization_crud.get(db, id=organization_id) if not org: raise NotFoundError( detail=f"Organization {organization_id} not found", @@ -122,7 +122,7 @@ def get_organization( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": organization_crud.get_member_count(db, organization_id=org.id) + "member_count": await organization_crud.get_member_count(db, organization_id=org.id) } return OrganizationResponse(**org_dict) @@ -140,12 +140,12 @@ def get_organization( description="Get all members of an organization (members can view)", operation_id="get_organization_members" ) -def get_organization_members( +async def get_organization_members( organization_id: UUID, pagination: PaginationParams = Depends(), is_active: bool = Query(True, description="Filter by active status"), current_user: User = Depends(require_org_membership), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Get all members of an organization. @@ -153,7 +153,7 @@ def get_organization_members( User must be a member of the organization to view members. """ try: - members, total = organization_crud.get_organization_members( + members, total = await organization_crud.get_organization_members( db, organization_id=organization_id, skip=pagination.offset, @@ -184,11 +184,11 @@ def get_organization_members( description="Update organization details (admin/owner only)", operation_id="update_organization" ) -def update_organization( +async def update_organization( organization_id: UUID, org_in: OrganizationUpdate, current_user: User = Depends(require_org_admin), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Update organization details. @@ -196,14 +196,14 @@ def update_organization( Requires owner or admin role in the organization. """ try: - org = organization_crud.get(db, id=organization_id) + org = await organization_crud.get(db, id=organization_id) if not org: raise NotFoundError( detail=f"Organization {organization_id} not found", error_code=ErrorCode.NOT_FOUND ) - updated_org = organization_crud.update(db, db_obj=org, obj_in=org_in) + updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in) logger.info(f"User {current_user.email} updated organization {updated_org.name}") org_dict = { @@ -215,7 +215,7 @@ def update_organization( "settings": updated_org.settings, "created_at": updated_org.created_at, "updated_at": updated_org.updated_at, - "member_count": organization_crud.get_member_count(db, organization_id=updated_org.id) + "member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id) } return OrganizationResponse(**org_dict) diff --git a/backend/app/api/routes/sessions.py b/backend/app/api/routes/sessions.py old mode 100644 new mode 100755 index ee4e710..b9dc3b5 --- a/backend/app/api/routes/sessions.py +++ b/backend/app/api/routes/sessions.py @@ -10,15 +10,15 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status, Request from slowapi import Limiter from slowapi.util import get_remote_address -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user -from app.core.database import get_db +from app.core.database_async import get_async_db from app.core.auth import decode_token from app.models.user import User from app.schemas.sessions import SessionResponse, SessionListResponse from app.schemas.common import MessageResponse -from app.crud.session import session as session_crud +from app.crud.session_async import session_async as session_crud from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode router = APIRouter() @@ -42,10 +42,10 @@ limiter = Limiter(key_func=get_remote_address) operation_id="list_my_sessions" ) @limiter.limit("30/minute") -def list_my_sessions( +async def list_my_sessions( request: Request, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ List all active sessions for the current user. @@ -59,7 +59,7 @@ def list_my_sessions( """ try: # Get all active sessions for user - sessions = session_crud.get_user_sessions( + sessions = await session_crud.get_user_sessions( db, user_id=str(current_user.id), active_only=True @@ -125,11 +125,11 @@ def list_my_sessions( operation_id="revoke_session" ) @limiter.limit("10/minute") -def revoke_session( +async def revoke_session( request: Request, session_id: UUID, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Revoke a specific session by ID. @@ -144,7 +144,7 @@ def revoke_session( """ try: # Get the session - session = session_crud.get(db, id=str(session_id)) + session = await session_crud.get(db, id=str(session_id)) if not session: raise NotFoundError( @@ -164,7 +164,7 @@ def revoke_session( ) # Deactivate the session - session_crud.deactivate(db, session_id=str(session_id)) + await session_crud.deactivate(db, session_id=str(session_id)) logger.info( f"User {current_user.id} revoked session {session_id} " @@ -201,10 +201,10 @@ def revoke_session( operation_id="cleanup_expired_sessions" ) @limiter.limit("5/minute") -def cleanup_expired_sessions( +async def cleanup_expired_sessions( request: Request, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Cleanup expired sessions for the current user. @@ -220,7 +220,7 @@ def cleanup_expired_sessions( from datetime import datetime, timezone # Get all sessions for user - all_sessions = session_crud.get_user_sessions( + all_sessions = await session_crud.get_user_sessions( db, user_id=str(current_user.id), active_only=False @@ -230,10 +230,10 @@ def cleanup_expired_sessions( deleted_count = 0 for s in all_sessions: if not s.is_active and s.expires_at < datetime.now(timezone.utc): - db.delete(s) + await db.delete(s) deleted_count += 1 - db.commit() + await db.commit() logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions") @@ -244,7 +244,7 @@ def cleanup_expired_sessions( except Exception as e: logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True) - db.rollback() + await db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to cleanup sessions" diff --git a/backend/app/api/routes/users.py b/backend/app/api/routes/users.py old mode 100644 new mode 100755 index fd38297..1412040 --- a/backend/app/api/routes/users.py +++ b/backend/app/api/routes/users.py @@ -6,13 +6,13 @@ from typing import Any, Optional from uuid import UUID from fastapi import APIRouter, Depends, Query, status, Request -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from slowapi import Limiter from slowapi.util import get_remote_address from app.api.dependencies.auth import get_current_user, get_current_superuser -from app.core.database import get_db -from app.crud.user import user as user_crud +from app.core.database_async import get_async_db +from app.crud.user_async import user_async as user_crud from app.models.user import User from app.schemas.users import UserResponse, UserUpdate, PasswordChange from app.schemas.common import ( @@ -52,13 +52,13 @@ limiter = Limiter(key_func=get_remote_address) """, operation_id="list_users" ) -def list_users( +async def list_users( pagination: PaginationParams = Depends(), sort: SortParams = Depends(), is_active: Optional[bool] = Query(None, description="Filter by active status"), is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), current_user: User = Depends(get_current_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ List all users with pagination, filtering, and sorting. @@ -74,7 +74,7 @@ def list_users( filters["is_superuser"] = is_superuser # Get paginated users with total count - users, total = user_crud.get_multi_with_total( + users, total = await user_crud.get_multi_with_total( db, skip=pagination.offset, limit=pagination.limit, @@ -135,10 +135,10 @@ def get_current_user_profile( """, operation_id="update_current_user" ) -def update_current_user( +async def update_current_user( user_update: UserUpdate, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Update current user's profile. @@ -154,7 +154,7 @@ def update_current_user( ) try: - updated_user = user_crud.update( + updated_user = await user_crud.update( db, db_obj=current_user, obj_in=user_update @@ -185,10 +185,10 @@ def update_current_user( """, operation_id="get_user_by_id" ) -def get_user_by_id( +async def get_user_by_id( user_id: UUID, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Get user by ID. @@ -206,7 +206,7 @@ def get_user_by_id( ) # Get user - user = user_crud.get(db, id=str(user_id)) + user = await user_crud.get(db, id=str(user_id)) if not user: raise NotFoundError( message=f"User with id {user_id} not found", @@ -232,11 +232,11 @@ def get_user_by_id( """, operation_id="update_user" ) -def update_user( +async def update_user( user_id: UUID, user_update: UserUpdate, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Update user by ID. @@ -257,7 +257,7 @@ def update_user( ) # Get user - user = user_crud.get(db, id=str(user_id)) + user = await user_crud.get(db, id=str(user_id)) if not user: raise NotFoundError( message=f"User with id {user_id} not found", @@ -273,7 +273,7 @@ def update_user( ) try: - updated_user = user_crud.update(db, db_obj=user, obj_in=user_update) + updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update) logger.info(f"User {user_id} updated by {current_user.id}") return updated_user except ValueError as e: @@ -300,11 +300,11 @@ def update_user( operation_id="change_current_user_password" ) @limiter.limit("5/minute") -def change_current_user_password( +async def change_current_user_password( request: Request, password_change: PasswordChange, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Change current user's password. @@ -312,7 +312,7 @@ def change_current_user_password( Requires current password for verification. """ try: - success = AuthService.change_password( + success = await AuthService.change_password( db=db, user_id=current_user.id, current_password=password_change.current_password, @@ -353,10 +353,10 @@ def change_current_user_password( """, operation_id="delete_user" ) -def delete_user( +async def delete_user( user_id: UUID, current_user: User = Depends(get_current_superuser), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_async_db) ) -> Any: """ Delete user by ID (superuser only). @@ -371,7 +371,7 @@ def delete_user( ) # Get user - user = user_crud.get(db, id=str(user_id)) + user = await user_crud.get(db, id=str(user_id)) if not user: raise NotFoundError( message=f"User with id {user_id} not found", @@ -380,7 +380,7 @@ def delete_user( try: # Use soft delete instead of hard delete - user_crud.soft_delete(db, id=str(user_id)) + await user_crud.soft_delete(db, id=str(user_id)) logger.info(f"User {user_id} soft-deleted by {current_user.id}") return MessageResponse( success=True, diff --git a/backend/app/core/database_async.py b/backend/app/core/database_async.py old mode 100644 new mode 100755 index aecfa14..e198b33 --- a/backend/app/core/database_async.py +++ b/backend/app/core/database_async.py @@ -159,6 +159,10 @@ async def check_async_database_health() -> bool: return False +# Alias for consistency with main.py +check_database_health = check_async_database_health + + async def init_async_db() -> None: """ Initialize async database tables. diff --git a/backend/app/crud/base_async.py b/backend/app/crud/base_async.py old mode 100644 new mode 100755 index 0354b21..fed8e8a --- a/backend/app/crud/base_async.py +++ b/backend/app/crud/base_async.py @@ -179,10 +179,25 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): raise async def get_multi_with_total( - self, db: AsyncSession, *, skip: int = 0, limit: int = 100 + self, + db: AsyncSession, + *, + skip: int = 0, + limit: int = 100, + sort_by: Optional[str] = None, + sort_order: str = "asc", + filters: Optional[Dict[str, Any]] = None ) -> Tuple[List[ModelType], int]: """ - Get multiple records with total count for pagination. + Get multiple records with total count, filtering, and sorting. + + Args: + db: Database session + skip: Number of records to skip + limit: Maximum number of records to return + sort_by: Field name to sort by (must be a valid model attribute) + sort_order: Sort order ("asc" or "desc") + filters: Dictionary of filters (field_name: value) Returns: Tuple of (items, total_count) @@ -196,16 +211,35 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): raise ValueError("Maximum limit is 1000") try: - # Get total count - count_result = await db.execute( - select(func.count(self.model.id)) - ) + # Build base query + query = select(self.model) + + # Exclude soft-deleted records by default + if hasattr(self.model, 'deleted_at'): + query = query.where(self.model.deleted_at.is_(None)) + + # Apply filters + if filters: + for field, value in filters.items(): + if hasattr(self.model, field) and value is not None: + query = query.where(getattr(self.model, field) == value) + + # Get total count (before pagination) + count_query = select(func.count()).select_from(query.alias()) + count_result = await db.execute(count_query) total = count_result.scalar_one() - # Get paginated items - items_result = await db.execute( - select(self.model).offset(skip).limit(limit) - ) + # Apply sorting + if sort_by and hasattr(self.model, sort_by): + sort_column = getattr(self.model, sort_by) + if sort_order.lower() == "desc": + query = query.order_by(sort_column.desc()) + else: + query = query.order_by(sort_column.asc()) + + # Apply pagination + query = query.offset(skip).limit(limit) + items_result = await db.execute(query) items = list(items_result.scalars().all()) return items, total @@ -226,3 +260,92 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): """Check if a record exists by ID.""" obj = await self.get(db, id=id) return obj is not None + + async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: + """ + Soft delete a record by setting deleted_at timestamp. + + Only works if the model has a 'deleted_at' column. + """ + from datetime import datetime, timezone + + # Validate UUID format and convert to UUID object if string + try: + if isinstance(id, uuid.UUID): + uuid_obj = id + else: + uuid_obj = uuid.UUID(str(id)) + except (ValueError, AttributeError, TypeError) as e: + logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}") + return None + + try: + result = await db.execute( + select(self.model).where(self.model.id == uuid_obj) + ) + obj = result.scalar_one_or_none() + + if obj is None: + logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion") + return None + + # Check if model supports soft deletes + if not hasattr(self.model, 'deleted_at'): + logger.error(f"{self.model.__name__} does not support soft deletes") + raise ValueError(f"{self.model.__name__} does not have a deleted_at column") + + # Set deleted_at timestamp + obj.deleted_at = datetime.now(timezone.utc) + db.add(obj) + await db.commit() + await db.refresh(obj) + return obj + except Exception as e: + await db.rollback() + logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) + raise + + async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: + """ + Restore a soft-deleted record by clearing the deleted_at timestamp. + + Only works if the model has a 'deleted_at' column. + """ + # Validate UUID format + try: + if isinstance(id, uuid.UUID): + uuid_obj = id + else: + uuid_obj = uuid.UUID(str(id)) + except (ValueError, AttributeError, TypeError) as e: + logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}") + return None + + try: + # Find the soft-deleted record + if hasattr(self.model, 'deleted_at'): + result = await db.execute( + select(self.model).where( + self.model.id == uuid_obj, + self.model.deleted_at.isnot(None) + ) + ) + obj = result.scalar_one_or_none() + else: + logger.error(f"{self.model.__name__} does not support soft deletes") + raise ValueError(f"{self.model.__name__} does not have a deleted_at column") + + if obj is None: + logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration") + return None + + # Clear deleted_at timestamp + obj.deleted_at = None + db.add(obj) + await db.commit() + await db.refresh(obj) + return obj + except Exception as e: + await db.rollback() + logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True) + raise diff --git a/backend/app/init_db.py b/backend/app/init_db.py old mode 100644 new mode 100755 diff --git a/backend/app/main.py b/backend/app/main.py old mode 100644 new mode 100755 index d2fb1cb..32c02c8 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -14,7 +14,7 @@ from sqlalchemy import text from app.api.main import api_router from app.core.config import settings -from app.core.database import get_db, check_database_health +from app.core.database_async import check_database_health from app.core.exceptions import ( APIException, api_exception_handler, @@ -218,7 +218,7 @@ async def health_check() -> JSONResponse: # Database health check using dedicated health check function try: - db_healthy = check_database_health() + db_healthy = await check_database_health() if db_healthy: health_status["checks"]["database"] = { "status": "healthy", diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py old mode 100644 new mode 100755 index 4941671..5d3537c --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -3,7 +3,8 @@ import logging from typing import Optional from uuid import UUID -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select from app.core.auth import ( verify_password, @@ -28,7 +29,7 @@ class AuthService: """Service for handling authentication operations""" @staticmethod - def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: + async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]: """ Authenticate a user with email and password. @@ -40,7 +41,8 @@ class AuthService: Returns: User if authenticated, None otherwise """ - user = db.query(User).filter(User.email == email).first() + result = await db.execute(select(User).where(User.email == email)) + user = result.scalar_one_or_none() if not user: return None @@ -54,7 +56,7 @@ class AuthService: return user @staticmethod - def create_user(db: Session, user_data: UserCreate) -> User: + async def create_user(db: AsyncSession, user_data: UserCreate) -> User: """ Create a new user. @@ -66,7 +68,8 @@ class AuthService: Created user """ # Check if user already exists - existing_user = db.query(User).filter(User.email == user_data.email).first() + result = await db.execute(select(User).where(User.email == user_data.email)) + existing_user = result.scalar_one_or_none() if existing_user: raise AuthenticationError("User with this email already exists") @@ -85,8 +88,8 @@ class AuthService: ) db.add(user) - db.commit() - db.refresh(user) + await db.commit() + await db.refresh(user) return user @@ -124,7 +127,7 @@ class AuthService: ) @staticmethod - def refresh_tokens(db: Session, refresh_token: str) -> Token: + async def refresh_tokens(db: AsyncSession, refresh_token: str) -> Token: """ Generate new tokens using a refresh token. @@ -150,7 +153,8 @@ class AuthService: user_id = token_data.user_id # Get user from database - user = db.query(User).filter(User.id == user_id).first() + result = await db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() if not user or not user.is_active: raise TokenInvalidError("Invalid user or inactive account") @@ -162,7 +166,7 @@ class AuthService: raise @staticmethod - def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool: + async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool: """ Change a user's password. @@ -178,7 +182,8 @@ class AuthService: Raises: AuthenticationError: If current password is incorrect """ - user = db.query(User).filter(User.id == user_id).first() + result = await db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() if not user: raise AuthenticationError("User not found") @@ -188,6 +193,6 @@ class AuthService: # Update password user.password_hash = get_password_hash(new_password) - db.commit() + await db.commit() return True diff --git a/backend/app/services/session_cleanup.py b/backend/app/services/session_cleanup.py old mode 100644 new mode 100755 index d15fb33..ff7fa04 --- a/backend/app/services/session_cleanup.py +++ b/backend/app/services/session_cleanup.py @@ -6,13 +6,13 @@ This service runs periodically to remove old session records from the database. import logging from datetime import datetime, timezone -from app.core.database import SessionLocal -from app.crud.session import session as session_crud +from app.core.database_async import AsyncSessionLocal +from app.crud.session_async import session_async as session_crud logger = logging.getLogger(__name__) -def cleanup_expired_sessions(keep_days: int = 30) -> int: +async def cleanup_expired_sessions(keep_days: int = 30) -> int: """ Clean up expired and inactive sessions. @@ -29,52 +29,58 @@ def cleanup_expired_sessions(keep_days: int = 30) -> int: """ logger.info("Starting session cleanup job...") - db = SessionLocal() - try: - # Use CRUD method to cleanup - count = session_crud.cleanup_expired(db, keep_days=keep_days) + async with AsyncSessionLocal() as db: + try: + # Use CRUD method to cleanup + count = await session_crud.cleanup_expired(db, keep_days=keep_days) - logger.info(f"Session cleanup complete: {count} sessions deleted") + logger.info(f"Session cleanup complete: {count} sessions deleted") - return count + return count - except Exception as e: - logger.error(f"Error during session cleanup: {str(e)}", exc_info=True) - return 0 - finally: - db.close() + except Exception as e: + logger.error(f"Error during session cleanup: {str(e)}", exc_info=True) + return 0 -def get_session_statistics() -> dict: +async def get_session_statistics() -> dict: """ Get statistics about current sessions. Returns: Dictionary with session stats """ - db = SessionLocal() - try: - from app.models.user_session import UserSession + async with AsyncSessionLocal() as db: + try: + from app.models.user_session import UserSession + from sqlalchemy import select, func - total_sessions = db.query(UserSession).count() - active_sessions = db.query(UserSession).filter(UserSession.is_active == True).count() - expired_sessions = db.query(UserSession).filter( - UserSession.expires_at < datetime.now(timezone.utc) - ).count() + total_result = await db.execute(select(func.count(UserSession.id))) + total_sessions = total_result.scalar_one() - stats = { - "total": total_sessions, - "active": active_sessions, - "inactive": total_sessions - active_sessions, - "expired": expired_sessions, - } + active_result = await db.execute( + select(func.count(UserSession.id)).where(UserSession.is_active == True) + ) + active_sessions = active_result.scalar_one() - logger.info(f"Session statistics: {stats}") + expired_result = await db.execute( + select(func.count(UserSession.id)).where( + UserSession.expires_at < datetime.now(timezone.utc) + ) + ) + expired_sessions = expired_result.scalar_one() - return stats + stats = { + "total": total_sessions, + "active": active_sessions, + "inactive": total_sessions - active_sessions, + "expired": expired_sessions, + } - except Exception as e: - logger.error(f"Error getting session statistics: {str(e)}", exc_info=True) - return {} - finally: - db.close() + logger.info(f"Session statistics: {stats}") + + return stats + + except Exception as e: + logger.error(f"Error getting session statistics: {str(e)}", exc_info=True) + return {}