Refactor backend to adopt async patterns across services, API routes, and CRUD operations

- Migrated database sessions and operations to `AsyncSession` for full async support.
- Updated all service methods and dependencies (`get_db` to `get_async_db`) to support async logic.
- Refactored admin, user, organization, session-related CRUD methods, and routes with await syntax.
- Improved consistency and performance with async SQLAlchemy patterns.
- Enhanced logging and error handling for async context.
This commit is contained in:
Felipe Cardoso
2025-10-31 21:57:12 +01:00
parent 19ecd04a41
commit 26ff08d9f9
14 changed files with 385 additions and 239 deletions

0
backend/app/__init__.py Normal file → Executable file
View File

24
backend/app/api/dependencies/auth.py Normal file → Executable file
View File

@@ -3,18 +3,19 @@ from typing import Optional
from fastapi import Depends, HTTPException, status, Header from fastapi import Depends, HTTPException, status, Header
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.database import get_db from app.core.database_async import get_async_db
from app.models.user import User from app.models.user import User
# OAuth2 configuration # OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
def get_current_user( async def get_current_user(
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
token: str = Depends(oauth2_scheme) token: str = Depends(oauth2_scheme)
) -> User: ) -> User:
""" """
@@ -35,7 +36,11 @@ def get_current_user(
token_data = get_token_data(token) token_data = get_token_data(token)
# Get user from database # Get user from database
user = db.query(User).filter(User.id == token_data.user_id).first() result = await db.execute(
select(User).where(User.id == token_data.user_id)
)
user = result.scalar_one_or_none()
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@@ -133,8 +138,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
return token return token
def get_optional_current_user( async def get_optional_current_user(
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
token: Optional[str] = Depends(get_optional_token) token: Optional[str] = Depends(get_optional_token)
) -> Optional[User]: ) -> Optional[User]:
""" """
@@ -153,7 +158,10 @@ def get_optional_current_user(
try: try:
token_data = get_token_data(token) token_data = get_token_data(token)
user = db.query(User).filter(User.id == token_data.user_id).first() result = await db.execute(
select(User).where(User.id == token_data.user_id)
)
user = result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
return None return None
return user return user

26
backend/app/api/dependencies/permissions.py Normal file → Executable file
View File

@@ -10,13 +10,13 @@ These dependencies are optional and flexible:
from typing import Optional from typing import Optional
from uuid import UUID from uuid import UUID
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db from app.core.database_async import get_async_db
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.crud.organization import organization as organization_crud from app.crud.organization_async import organization_async as organization_crud
def require_superuser( def require_superuser(
@@ -73,11 +73,11 @@ class OrganizationPermission:
""" """
self.allowed_roles = allowed_roles self.allowed_roles = allowed_roles
def __call__( async def __call__(
self, self,
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> User: ) -> User:
""" """
Check if user has required role in the organization. Check if user has required role in the organization.
@@ -98,7 +98,7 @@ class OrganizationPermission:
return current_user return current_user
# Get user's role in organization # Get user's role in organization
user_role = organization_crud.get_user_role_in_org( user_role = await organization_crud.get_user_role_in_org(
db, db,
user_id=current_user.id, user_id=current_user.id,
organization_id=organization_id organization_id=organization_id
@@ -129,10 +129,10 @@ require_org_member = OrganizationPermission([
]) ])
def get_current_org_role( async def get_current_org_role(
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Optional[OrganizationRole]: ) -> Optional[OrganizationRole]:
""" """
Get the current user's role in an organization. Get the current user's role in an organization.
@@ -142,7 +142,7 @@ def get_current_org_role(
Example: Example:
@router.get("/organizations/{org_id}/items") @router.get("/organizations/{org_id}/items")
def list_items( async def list_items(
org_id: UUID, org_id: UUID,
role: OrganizationRole = Depends(get_current_org_role) role: OrganizationRole = Depends(get_current_org_role)
): ):
@@ -153,17 +153,17 @@ def get_current_org_role(
if current_user.is_superuser: if current_user.is_superuser:
return OrganizationRole.OWNER return OrganizationRole.OWNER
return organization_crud.get_user_role_in_org( return await organization_crud.get_user_role_in_org(
db, db,
user_id=current_user.id, user_id=current_user.id,
organization_id=organization_id organization_id=organization_id
) )
def require_org_membership( async def require_org_membership(
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> User: ) -> User:
""" """
Ensure user is a member of the organization (any role). Ensure user is a member of the organization (any role).
@@ -173,7 +173,7 @@ def require_org_membership(
if current_user.is_superuser: if current_user.is_superuser:
return current_user return current_user
user_role = organization_crud.get_user_role_in_org( user_role = await organization_crud.get_user_role_in_org(
db, db,
user_id=current_user.id, user_id=current_user.id,
organization_id=organization_id organization_id=organization_id

138
backend/app/api/routes/admin.py Normal file → Executable file
View File

@@ -11,13 +11,13 @@ from uuid import UUID
from enum import Enum from enum import Enum
from fastapi import APIRouter, Depends, Query, Body, status from fastapi import APIRouter, Depends, Query, Body, status
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.api.dependencies.permissions import require_superuser from app.api.dependencies.permissions import require_superuser
from app.core.database import get_db from app.core.database_async import get_async_db
from app.crud.user import user as user_crud from app.crud.user_async import user_async as user_crud
from app.crud.organization import organization as organization_crud from app.crud.organization_async import organization_async as organization_crud
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
from app.schemas.users import UserResponse, UserCreate, UserUpdate from app.schemas.users import UserResponse, UserCreate, UserUpdate
@@ -73,14 +73,14 @@ class BulkActionResult(BaseModel):
description="Get paginated list of all users with filtering and search (admin only)", description="Get paginated list of all users with filtering and search (admin only)",
operation_id="admin_list_users" operation_id="admin_list_users"
) )
def admin_list_users( async def admin_list_users(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
sort: SortParams = Depends(), sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"), is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
search: Optional[str] = Query(None, description="Search by email, name"), search: Optional[str] = Query(None, description="Search by email, name"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
List all users with comprehensive filtering and search. List all users with comprehensive filtering and search.
@@ -96,7 +96,7 @@ def admin_list_users(
filters["is_superuser"] = is_superuser filters["is_superuser"] = is_superuser
# Get users with search # Get users with search
users, total = user_crud.get_multi_with_total( users, total = await user_crud.get_multi_with_total(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -128,10 +128,10 @@ def admin_list_users(
description="Create a new user (admin only)", description="Create a new user (admin only)",
operation_id="admin_create_user" operation_id="admin_create_user"
) )
def admin_create_user( async def admin_create_user(
user_in: UserCreate, user_in: UserCreate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Create a new user with admin privileges. Create a new user with admin privileges.
@@ -139,7 +139,7 @@ def admin_create_user(
Allows setting is_superuser and other fields. Allows setting is_superuser and other fields.
""" """
try: try:
user = user_crud.create(db, obj_in=user_in) user = await user_crud.create(db, obj_in=user_in)
logger.info(f"Admin {admin.email} created user {user.email}") logger.info(f"Admin {admin.email} created user {user.email}")
return user return user
except ValueError as e: except ValueError as e:
@@ -160,13 +160,13 @@ def admin_create_user(
description="Get detailed user information (admin only)", description="Get detailed user information (admin only)",
operation_id="admin_get_user" operation_id="admin_get_user"
) )
def admin_get_user( async def admin_get_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Get detailed information about a specific user.""" """Get detailed information about a specific user."""
user = user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
detail=f"User {user_id} not found", detail=f"User {user_id} not found",
@@ -182,22 +182,22 @@ def admin_get_user(
description="Update user information (admin only)", description="Update user information (admin only)",
operation_id="admin_update_user" operation_id="admin_update_user"
) )
def admin_update_user( async def admin_update_user(
user_id: UUID, user_id: UUID,
user_in: UserUpdate, user_in: UserUpdate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Update user information with admin privileges.""" """Update user information with admin privileges."""
try: try:
user = user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
detail=f"User {user_id} not found", detail=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND error_code=ErrorCode.USER_NOT_FOUND
) )
updated_user = user_crud.update(db, db_obj=user, obj_in=user_in) updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
logger.info(f"Admin {admin.email} updated user {updated_user.email}") logger.info(f"Admin {admin.email} updated user {updated_user.email}")
return updated_user return updated_user
@@ -215,14 +215,14 @@ def admin_update_user(
description="Soft delete a user (admin only)", description="Soft delete a user (admin only)",
operation_id="admin_delete_user" operation_id="admin_delete_user"
) )
def admin_delete_user( async def admin_delete_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Soft delete a user (sets deleted_at timestamp).""" """Soft delete a user (sets deleted_at timestamp)."""
try: try:
user = user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
detail=f"User {user_id} not found", detail=f"User {user_id} not found",
@@ -236,7 +236,7 @@ def admin_delete_user(
error_code=ErrorCode.OPERATION_FORBIDDEN error_code=ErrorCode.OPERATION_FORBIDDEN
) )
user_crud.soft_delete(db, id=user_id) await user_crud.soft_delete(db, id=user_id)
logger.info(f"Admin {admin.email} deleted user {user.email}") logger.info(f"Admin {admin.email} deleted user {user.email}")
return MessageResponse( return MessageResponse(
@@ -258,21 +258,21 @@ def admin_delete_user(
description="Activate a user account (admin only)", description="Activate a user account (admin only)",
operation_id="admin_activate_user" operation_id="admin_activate_user"
) )
def admin_activate_user( async def admin_activate_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Activate a user account.""" """Activate a user account."""
try: try:
user = user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
detail=f"User {user_id} not found", detail=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND error_code=ErrorCode.USER_NOT_FOUND
) )
user_crud.update(db, db_obj=user, obj_in={"is_active": True}) await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
logger.info(f"Admin {admin.email} activated user {user.email}") logger.info(f"Admin {admin.email} activated user {user.email}")
return MessageResponse( return MessageResponse(
@@ -294,14 +294,14 @@ def admin_activate_user(
description="Deactivate a user account (admin only)", description="Deactivate a user account (admin only)",
operation_id="admin_deactivate_user" operation_id="admin_deactivate_user"
) )
def admin_deactivate_user( async def admin_deactivate_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Deactivate a user account.""" """Deactivate a user account."""
try: try:
user = user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
detail=f"User {user_id} not found", detail=f"User {user_id} not found",
@@ -315,7 +315,7 @@ def admin_deactivate_user(
error_code=ErrorCode.OPERATION_FORBIDDEN error_code=ErrorCode.OPERATION_FORBIDDEN
) )
user_crud.update(db, db_obj=user, obj_in={"is_active": False}) await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
logger.info(f"Admin {admin.email} deactivated user {user.email}") logger.info(f"Admin {admin.email} deactivated user {user.email}")
return MessageResponse( return MessageResponse(
@@ -337,10 +337,10 @@ def admin_deactivate_user(
description="Perform bulk actions on multiple users (admin only)", description="Perform bulk actions on multiple users (admin only)",
operation_id="admin_bulk_user_action" operation_id="admin_bulk_user_action"
) )
def admin_bulk_user_action( async def admin_bulk_user_action(
bulk_action: BulkUserAction, bulk_action: BulkUserAction,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Perform bulk actions on multiple users. Perform bulk actions on multiple users.
@@ -354,7 +354,7 @@ def admin_bulk_user_action(
try: try:
for user_id in bulk_action.user_ids: for user_id in bulk_action.user_ids:
try: try:
user = user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
failed_count += 1 failed_count += 1
failed_ids.append(user_id) failed_ids.append(user_id)
@@ -367,11 +367,11 @@ def admin_bulk_user_action(
continue continue
if bulk_action.action == BulkAction.ACTIVATE: if bulk_action.action == BulkAction.ACTIVATE:
user_crud.update(db, db_obj=user, obj_in={"is_active": True}) await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
elif bulk_action.action == BulkAction.DEACTIVATE: elif bulk_action.action == BulkAction.DEACTIVATE:
user_crud.update(db, db_obj=user, obj_in={"is_active": False}) await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
elif bulk_action.action == BulkAction.DELETE: elif bulk_action.action == BulkAction.DELETE:
user_crud.soft_delete(db, id=user_id) await user_crud.soft_delete(db, id=user_id)
affected_count += 1 affected_count += 1
@@ -407,16 +407,16 @@ def admin_bulk_user_action(
description="Get paginated list of all organizations (admin only)", description="Get paginated list of all organizations (admin only)",
operation_id="admin_list_organizations" operation_id="admin_list_organizations"
) )
def admin_list_organizations( async def admin_list_organizations(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"), is_active: Optional[bool] = Query(None, description="Filter by active status"),
search: Optional[str] = Query(None, description="Search by name, slug, description"), search: Optional[str] = Query(None, description="Search by name, slug, description"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""List all organizations with filtering and search.""" """List all organizations with filtering and search."""
try: try:
orgs, total = organization_crud.get_multi_with_filters( orgs, total = await organization_crud.get_multi_with_filters(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -438,7 +438,7 @@ def admin_list_organizations(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=org.id) "member_count": await organization_crud.get_member_count(db, organization_id=org.id)
} }
orgs_with_count.append(OrganizationResponse(**org_dict)) orgs_with_count.append(OrganizationResponse(**org_dict))
@@ -464,14 +464,14 @@ def admin_list_organizations(
description="Create a new organization (admin only)", description="Create a new organization (admin only)",
operation_id="admin_create_organization" operation_id="admin_create_organization"
) )
def admin_create_organization( async def admin_create_organization(
org_in: OrganizationCreate, org_in: OrganizationCreate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Create a new organization.""" """Create a new organization."""
try: try:
org = organization_crud.create(db, obj_in=org_in) org = await organization_crud.create(db, obj_in=org_in)
logger.info(f"Admin {admin.email} created organization {org.name}") logger.info(f"Admin {admin.email} created organization {org.name}")
# Add member count # Add member count
@@ -506,13 +506,13 @@ def admin_create_organization(
description="Get detailed organization information (admin only)", description="Get detailed organization information (admin only)",
operation_id="admin_get_organization" operation_id="admin_get_organization"
) )
def admin_get_organization( async def admin_get_organization(
org_id: UUID, org_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Get detailed information about a specific organization.""" """Get detailed information about a specific organization."""
org = organization_crud.get(db, id=org_id) org = await organization_crud.get(db, id=org_id)
if not org: if not org:
raise NotFoundError( raise NotFoundError(
detail=f"Organization {org_id} not found", detail=f"Organization {org_id} not found",
@@ -528,7 +528,7 @@ def admin_get_organization(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=org.id) "member_count": await organization_crud.get_member_count(db, organization_id=org.id)
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
@@ -540,22 +540,22 @@ def admin_get_organization(
description="Update organization information (admin only)", description="Update organization information (admin only)",
operation_id="admin_update_organization" operation_id="admin_update_organization"
) )
def admin_update_organization( async def admin_update_organization(
org_id: UUID, org_id: UUID,
org_in: OrganizationUpdate, org_in: OrganizationUpdate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Update organization information.""" """Update organization information."""
try: try:
org = organization_crud.get(db, id=org_id) org = await organization_crud.get(db, id=org_id)
if not org: if not org:
raise NotFoundError( raise NotFoundError(
detail=f"Organization {org_id} not found", detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND
) )
updated_org = organization_crud.update(db, db_obj=org, obj_in=org_in) updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info(f"Admin {admin.email} updated organization {updated_org.name}") logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
org_dict = { org_dict = {
@@ -567,7 +567,7 @@ def admin_update_organization(
"settings": updated_org.settings, "settings": updated_org.settings,
"created_at": updated_org.created_at, "created_at": updated_org.created_at,
"updated_at": updated_org.updated_at, "updated_at": updated_org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=updated_org.id) "member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
@@ -585,21 +585,21 @@ def admin_update_organization(
description="Delete an organization (admin only)", description="Delete an organization (admin only)",
operation_id="admin_delete_organization" operation_id="admin_delete_organization"
) )
def admin_delete_organization( async def admin_delete_organization(
org_id: UUID, org_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Delete an organization and all its relationships.""" """Delete an organization and all its relationships."""
try: try:
org = organization_crud.get(db, id=org_id) org = await organization_crud.get(db, id=org_id)
if not org: if not org:
raise NotFoundError( raise NotFoundError(
detail=f"Organization {org_id} not found", detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND
) )
organization_crud.remove(db, id=org_id) await organization_crud.remove(db, id=org_id)
logger.info(f"Admin {admin.email} deleted organization {org.name}") logger.info(f"Admin {admin.email} deleted organization {org.name}")
return MessageResponse( return MessageResponse(
@@ -621,23 +621,23 @@ def admin_delete_organization(
description="Get all members of an organization (admin only)", description="Get all members of an organization (admin only)",
operation_id="admin_list_organization_members" operation_id="admin_list_organization_members"
) )
def admin_list_organization_members( async def admin_list_organization_members(
org_id: UUID, org_id: UUID,
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(True, description="Filter by active status"), is_active: Optional[bool] = Query(True, description="Filter by active status"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""List all members of an organization.""" """List all members of an organization."""
try: try:
org = organization_crud.get(db, id=org_id) org = await organization_crud.get(db, id=org_id)
if not org: if not org:
raise NotFoundError( raise NotFoundError(
detail=f"Organization {org_id} not found", detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND
) )
members, total = organization_crud.get_organization_members( members, total = await organization_crud.get_organization_members(
db, db,
organization_id=org_id, organization_id=org_id,
skip=pagination.offset, skip=pagination.offset,
@@ -677,29 +677,29 @@ class AddMemberRequest(BaseModel):
description="Add a user to an organization (admin only)", description="Add a user to an organization (admin only)",
operation_id="admin_add_organization_member" operation_id="admin_add_organization_member"
) )
def admin_add_organization_member( async def admin_add_organization_member(
org_id: UUID, org_id: UUID,
request: AddMemberRequest, request: AddMemberRequest,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Add a user to an organization.""" """Add a user to an organization."""
try: try:
org = organization_crud.get(db, id=org_id) org = await organization_crud.get(db, id=org_id)
if not org: if not org:
raise NotFoundError( raise NotFoundError(
detail=f"Organization {org_id} not found", detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND
) )
user = user_crud.get(db, id=request.user_id) user = await user_crud.get(db, id=request.user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
detail=f"User {request.user_id} not found", detail=f"User {request.user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND error_code=ErrorCode.USER_NOT_FOUND
) )
organization_crud.add_user( await organization_crud.add_user(
db, db,
organization_id=org_id, organization_id=org_id,
user_id=request.user_id, user_id=request.user_id,
@@ -733,29 +733,29 @@ def admin_add_organization_member(
description="Remove a user from an organization (admin only)", description="Remove a user from an organization (admin only)",
operation_id="admin_remove_organization_member" operation_id="admin_remove_organization_member"
) )
def admin_remove_organization_member( async def admin_remove_organization_member(
org_id: UUID, org_id: UUID,
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
"""Remove a user from an organization.""" """Remove a user from an organization."""
try: try:
org = organization_crud.get(db, id=org_id) org = await organization_crud.get(db, id=org_id)
if not org: if not org:
raise NotFoundError( raise NotFoundError(
detail=f"Organization {org_id} not found", detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND
) )
user = user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
detail=f"User {user_id} not found", detail=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND error_code=ErrorCode.USER_NOT_FOUND
) )
success = organization_crud.remove_user( success = await organization_crud.remove_user(
db, db,
organization_id=org_id, organization_id=org_id,
user_id=user_id user_id=user_id

62
backend/app/api/routes/auth.py Normal file → Executable file
View File

@@ -8,11 +8,11 @@ from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter from slowapi import Limiter
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
from app.core.database import get_db from app.core.database_async import get_async_db
from app.models.user import User from app.models.user import User
from app.schemas.users import ( from app.schemas.users import (
UserCreate, UserCreate,
@@ -29,8 +29,8 @@ from app.services.auth_service import AuthService, AuthenticationError
from app.services.email_service import email_service from app.services.email_service import email_service
from app.utils.security import create_password_reset_token, verify_password_reset_token from app.utils.security import create_password_reset_token, verify_password_reset_token
from app.utils.device import extract_device_info from app.utils.device import extract_device_info
from app.crud.user import user as user_crud from app.crud.user_async import user_async as user_crud
from app.crud.session import session as session_crud from app.crud.session_async import session_async as session_crud
from app.core.auth import get_password_hash from app.core.auth import get_password_hash
router = APIRouter() router = APIRouter()
@@ -49,7 +49,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1
async def register_user( async def register_user(
request: Request, request: Request,
user_data: UserCreate, user_data: UserCreate,
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Register a new user. Register a new user.
@@ -58,7 +58,7 @@ async def register_user(
The created user information. The created user information.
""" """
try: try:
user = AuthService.create_user(db, user_data) user = await AuthService.create_user(db, user_data)
return user return user
except AuthenticationError as e: except AuthenticationError as e:
logger.warning(f"Registration failed: {str(e)}") logger.warning(f"Registration failed: {str(e)}")
@@ -79,7 +79,7 @@ async def register_user(
async def login( async def login(
request: Request, request: Request,
login_data: LoginRequest, login_data: LoginRequest,
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Login with username and password. Login with username and password.
@@ -91,7 +91,7 @@ async def login(
""" """
try: try:
# Attempt to authenticate the user # Attempt to authenticate the user
user = AuthService.authenticate_user(db, login_data.email, login_data.password) user = await AuthService.authenticate_user(db, login_data.email, login_data.password)
# Explicitly check for None result and raise correct exception # Explicitly check for None result and raise correct exception
if user is None: if user is None:
@@ -126,7 +126,7 @@ async def login(
location_country=device_info.location_country, location_country=device_info.location_country,
) )
session_crud.create_session(db, obj_in=session_data) await session_crud.create_session(db, obj_in=session_data)
logger.info( logger.info(
f"User login successful: {user.email} from {device_info.device_name} " f"User login successful: {user.email} from {device_info.device_name} "
@@ -163,7 +163,7 @@ async def login(
async def login_oauth( async def login_oauth(
request: Request, request: Request,
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
OAuth2-compatible login endpoint, used by the OpenAPI UI. OAuth2-compatible login endpoint, used by the OpenAPI UI.
@@ -174,7 +174,7 @@ async def login_oauth(
Access and refresh tokens. Access and refresh tokens.
""" """
try: try:
user = AuthService.authenticate_user(db, form_data.username, form_data.password) user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
if user is None: if user is None:
raise HTTPException( raise HTTPException(
@@ -207,7 +207,7 @@ async def login_oauth(
location_country=device_info.location_country, location_country=device_info.location_country,
) )
session_crud.create_session(db, obj_in=session_data) await session_crud.create_session(db, obj_in=session_data)
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}") logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
except Exception as session_err: except Exception as session_err:
@@ -241,7 +241,7 @@ async def login_oauth(
async def refresh_token( async def refresh_token(
request: Request, request: Request,
refresh_data: RefreshTokenRequest, refresh_data: RefreshTokenRequest,
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Refresh access token using a refresh token. Refresh access token using a refresh token.
@@ -256,7 +256,7 @@ async def refresh_token(
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh") refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
# Check if session exists and is active # Check if session exists and is active
session = session_crud.get_active_by_jti(db, jti=refresh_payload.jti) session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
if not session: if not session:
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}") logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
@@ -267,14 +267,14 @@ async def refresh_token(
) )
# Generate new tokens # Generate new tokens
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token) tokens = await AuthService.refresh_tokens(db, refresh_data.refresh_token)
# Decode new refresh token to get new JTI # Decode new refresh token to get new JTI
new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh") new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
# Update session with new refresh token JTI and expiration # Update session with new refresh token JTI and expiration
try: try:
session_crud.update_refresh_token( await session_crud.update_refresh_token(
db, db,
session=session, session=session,
new_jti=new_refresh_payload.jti, new_jti=new_refresh_payload.jti,
@@ -344,7 +344,7 @@ async def get_current_user_info(
async def request_password_reset( async def request_password_reset(
request: Request, request: Request,
reset_request: PasswordResetRequest, reset_request: PasswordResetRequest,
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Request a password reset. Request a password reset.
@@ -354,7 +354,7 @@ async def request_password_reset(
""" """
try: try:
# Look up user by email # Look up user by email
user = user_crud.get_by_email(db, email=reset_request.email) user = await user_crud.get_by_email(db, email=reset_request.email)
# Only send email if user exists and is active # Only send email if user exists and is active
if user and user.is_active: if user and user.is_active:
@@ -399,10 +399,10 @@ async def request_password_reset(
operation_id="confirm_password_reset" operation_id="confirm_password_reset"
) )
@limiter.limit("5/minute") @limiter.limit("5/minute")
def confirm_password_reset( async def confirm_password_reset(
request: Request, request: Request,
reset_confirm: PasswordResetConfirm, reset_confirm: PasswordResetConfirm,
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Confirm password reset with token. Confirm password reset with token.
@@ -420,7 +420,7 @@ def confirm_password_reset(
) )
# Look up user # Look up user
user = user_crud.get_by_email(db, email=email) user = await user_crud.get_by_email(db, email=email)
if not user: if not user:
raise HTTPException( raise HTTPException(
@@ -437,7 +437,7 @@ def confirm_password_reset(
# Update password # Update password
user.password_hash = get_password_hash(reset_confirm.new_password) user.password_hash = get_password_hash(reset_confirm.new_password)
db.add(user) db.add(user)
db.commit() await db.commit()
logger.info(f"Password reset successful for {user.email}") logger.info(f"Password reset successful for {user.email}")
@@ -450,7 +450,7 @@ def confirm_password_reset(
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True) logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while resetting your password" detail="An error occurred while resetting your password"
@@ -474,11 +474,11 @@ def confirm_password_reset(
operation_id="logout" operation_id="logout"
) )
@limiter.limit("10/minute") @limiter.limit("10/minute")
def logout( async def logout(
request: Request, request: Request,
logout_request: LogoutRequest, logout_request: LogoutRequest,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Logout from current device by deactivating the session. Logout from current device by deactivating the session.
@@ -505,7 +505,7 @@ def logout(
) )
# Find the session by JTI # Find the session by JTI
session = session_crud.get_by_jti(db, jti=refresh_payload.jti) session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
if session: if session:
# Verify session belongs to current user (security check) # Verify session belongs to current user (security check)
@@ -520,7 +520,7 @@ def logout(
) )
# Deactivate the session # Deactivate the session
session_crud.deactivate(db, session_id=str(session.id)) await session_crud.deactivate(db, session_id=str(session.id))
logger.info( logger.info(
f"User {current_user.id} logged out from {session.device_name} " f"User {current_user.id} logged out from {session.device_name} "
@@ -563,10 +563,10 @@ def logout(
operation_id="logout_all" operation_id="logout_all"
) )
@limiter.limit("5/minute") @limiter.limit("5/minute")
def logout_all( async def logout_all(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Logout from all devices by deactivating all user sessions. Logout from all devices by deactivating all user sessions.
@@ -580,7 +580,7 @@ def logout_all(
""" """
try: try:
# Deactivate all sessions for this user # Deactivate all sessions for this user
count = session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id)) count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)") logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
@@ -591,7 +591,7 @@ def logout_all(
except Exception as e: except Exception as e:
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True) logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while logging out" detail="An error occurred while logging out"

38
backend/app/api/routes/organizations.py Normal file → Executable file
View File

@@ -9,12 +9,12 @@ from typing import Any, List, Optional
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query, status from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership, get_current_org_role from app.api.dependencies.permissions import require_org_admin, require_org_membership, get_current_org_role
from app.core.database import get_db from app.core.database_async import get_async_db
from app.crud.organization import organization as organization_crud from app.crud.organization_async import organization_async as organization_crud
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
from app.schemas.organizations import ( from app.schemas.organizations import (
@@ -42,10 +42,10 @@ router = APIRouter()
description="Get all organizations the current user belongs to", description="Get all organizations the current user belongs to",
operation_id="get_my_organizations" operation_id="get_my_organizations"
) )
def get_my_organizations( async def get_my_organizations(
is_active: bool = Query(True, description="Filter by active membership"), is_active: bool = Query(True, description="Filter by active membership"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Get all organizations the current user belongs to. Get all organizations the current user belongs to.
@@ -53,7 +53,7 @@ def get_my_organizations(
Returns organizations with member count for each. Returns organizations with member count for each.
""" """
try: try:
orgs = organization_crud.get_user_organizations( orgs = await organization_crud.get_user_organizations(
db, db,
user_id=current_user.id, user_id=current_user.id,
is_active=is_active is_active=is_active
@@ -77,7 +77,7 @@ def get_my_organizations(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=org.id) "member_count": await organization_crud.get_member_count(db, organization_id=org.id)
} }
orgs_with_data.append(OrganizationResponse(**org_dict)) orgs_with_data.append(OrganizationResponse(**org_dict))
@@ -95,10 +95,10 @@ def get_my_organizations(
description="Get details of an organization the user belongs to", description="Get details of an organization the user belongs to",
operation_id="get_organization" operation_id="get_organization"
) )
def get_organization( async def get_organization(
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(require_org_membership), current_user: User = Depends(require_org_membership),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Get details of a specific organization. Get details of a specific organization.
@@ -106,7 +106,7 @@ def get_organization(
User must be a member of the organization. User must be a member of the organization.
""" """
try: try:
org = organization_crud.get(db, id=organization_id) org = await organization_crud.get(db, id=organization_id)
if not org: if not org:
raise NotFoundError( raise NotFoundError(
detail=f"Organization {organization_id} not found", detail=f"Organization {organization_id} not found",
@@ -122,7 +122,7 @@ def get_organization(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=org.id) "member_count": await organization_crud.get_member_count(db, organization_id=org.id)
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
@@ -140,12 +140,12 @@ def get_organization(
description="Get all members of an organization (members can view)", description="Get all members of an organization (members can view)",
operation_id="get_organization_members" operation_id="get_organization_members"
) )
def get_organization_members( async def get_organization_members(
organization_id: UUID, organization_id: UUID,
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"), is_active: bool = Query(True, description="Filter by active status"),
current_user: User = Depends(require_org_membership), current_user: User = Depends(require_org_membership),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Get all members of an organization. Get all members of an organization.
@@ -153,7 +153,7 @@ def get_organization_members(
User must be a member of the organization to view members. User must be a member of the organization to view members.
""" """
try: try:
members, total = organization_crud.get_organization_members( members, total = await organization_crud.get_organization_members(
db, db,
organization_id=organization_id, organization_id=organization_id,
skip=pagination.offset, skip=pagination.offset,
@@ -184,11 +184,11 @@ def get_organization_members(
description="Update organization details (admin/owner only)", description="Update organization details (admin/owner only)",
operation_id="update_organization" operation_id="update_organization"
) )
def update_organization( async def update_organization(
organization_id: UUID, organization_id: UUID,
org_in: OrganizationUpdate, org_in: OrganizationUpdate,
current_user: User = Depends(require_org_admin), current_user: User = Depends(require_org_admin),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Update organization details. Update organization details.
@@ -196,14 +196,14 @@ def update_organization(
Requires owner or admin role in the organization. Requires owner or admin role in the organization.
""" """
try: try:
org = organization_crud.get(db, id=organization_id) org = await organization_crud.get(db, id=organization_id)
if not org: if not org:
raise NotFoundError( raise NotFoundError(
detail=f"Organization {organization_id} not found", detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND
) )
updated_org = organization_crud.update(db, db_obj=org, obj_in=org_in) updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info(f"User {current_user.email} updated organization {updated_org.name}") logger.info(f"User {current_user.email} updated organization {updated_org.name}")
org_dict = { org_dict = {
@@ -215,7 +215,7 @@ def update_organization(
"settings": updated_org.settings, "settings": updated_org.settings,
"created_at": updated_org.created_at, "created_at": updated_org.created_at,
"updated_at": updated_org.updated_at, "updated_at": updated_org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=updated_org.id) "member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)

32
backend/app/api/routes/sessions.py Normal file → Executable file
View File

@@ -10,15 +10,15 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, status, Request
from slowapi import Limiter from slowapi import Limiter
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.core.database import get_db from app.core.database_async import get_async_db
from app.core.auth import decode_token from app.core.auth import decode_token
from app.models.user import User from app.models.user import User
from app.schemas.sessions import SessionResponse, SessionListResponse from app.schemas.sessions import SessionResponse, SessionListResponse
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.crud.session import session as session_crud from app.crud.session_async import session_async as session_crud
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
router = APIRouter() router = APIRouter()
@@ -42,10 +42,10 @@ limiter = Limiter(key_func=get_remote_address)
operation_id="list_my_sessions" operation_id="list_my_sessions"
) )
@limiter.limit("30/minute") @limiter.limit("30/minute")
def list_my_sessions( async def list_my_sessions(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
List all active sessions for the current user. List all active sessions for the current user.
@@ -59,7 +59,7 @@ def list_my_sessions(
""" """
try: try:
# Get all active sessions for user # Get all active sessions for user
sessions = session_crud.get_user_sessions( sessions = await session_crud.get_user_sessions(
db, db,
user_id=str(current_user.id), user_id=str(current_user.id),
active_only=True active_only=True
@@ -125,11 +125,11 @@ def list_my_sessions(
operation_id="revoke_session" operation_id="revoke_session"
) )
@limiter.limit("10/minute") @limiter.limit("10/minute")
def revoke_session( async def revoke_session(
request: Request, request: Request,
session_id: UUID, session_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Revoke a specific session by ID. Revoke a specific session by ID.
@@ -144,7 +144,7 @@ def revoke_session(
""" """
try: try:
# Get the session # Get the session
session = session_crud.get(db, id=str(session_id)) session = await session_crud.get(db, id=str(session_id))
if not session: if not session:
raise NotFoundError( raise NotFoundError(
@@ -164,7 +164,7 @@ def revoke_session(
) )
# Deactivate the session # Deactivate the session
session_crud.deactivate(db, session_id=str(session_id)) await session_crud.deactivate(db, session_id=str(session_id))
logger.info( logger.info(
f"User {current_user.id} revoked session {session_id} " f"User {current_user.id} revoked session {session_id} "
@@ -201,10 +201,10 @@ def revoke_session(
operation_id="cleanup_expired_sessions" operation_id="cleanup_expired_sessions"
) )
@limiter.limit("5/minute") @limiter.limit("5/minute")
def cleanup_expired_sessions( async def cleanup_expired_sessions(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Cleanup expired sessions for the current user. Cleanup expired sessions for the current user.
@@ -220,7 +220,7 @@ def cleanup_expired_sessions(
from datetime import datetime, timezone from datetime import datetime, timezone
# Get all sessions for user # Get all sessions for user
all_sessions = session_crud.get_user_sessions( all_sessions = await session_crud.get_user_sessions(
db, db,
user_id=str(current_user.id), user_id=str(current_user.id),
active_only=False active_only=False
@@ -230,10 +230,10 @@ def cleanup_expired_sessions(
deleted_count = 0 deleted_count = 0
for s in all_sessions: for s in all_sessions:
if not s.is_active and s.expires_at < datetime.now(timezone.utc): if not s.is_active and s.expires_at < datetime.now(timezone.utc):
db.delete(s) await db.delete(s)
deleted_count += 1 deleted_count += 1
db.commit() await db.commit()
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions") logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
@@ -244,7 +244,7 @@ def cleanup_expired_sessions(
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True) logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cleanup sessions" detail="Failed to cleanup sessions"

46
backend/app/api/routes/users.py Normal file → Executable file
View File

@@ -6,13 +6,13 @@ from typing import Any, Optional
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query, status, Request from fastapi import APIRouter, Depends, Query, status, Request
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from slowapi import Limiter from slowapi import Limiter
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from app.api.dependencies.auth import get_current_user, get_current_superuser from app.api.dependencies.auth import get_current_user, get_current_superuser
from app.core.database import get_db from app.core.database_async import get_async_db
from app.crud.user import user as user_crud from app.crud.user_async import user_async as user_crud
from app.models.user import User from app.models.user import User
from app.schemas.users import UserResponse, UserUpdate, PasswordChange from app.schemas.users import UserResponse, UserUpdate, PasswordChange
from app.schemas.common import ( from app.schemas.common import (
@@ -52,13 +52,13 @@ limiter = Limiter(key_func=get_remote_address)
""", """,
operation_id="list_users" operation_id="list_users"
) )
def list_users( async def list_users(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
sort: SortParams = Depends(), sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"), is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
List all users with pagination, filtering, and sorting. List all users with pagination, filtering, and sorting.
@@ -74,7 +74,7 @@ def list_users(
filters["is_superuser"] = is_superuser filters["is_superuser"] = is_superuser
# Get paginated users with total count # Get paginated users with total count
users, total = user_crud.get_multi_with_total( users, total = await user_crud.get_multi_with_total(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -135,10 +135,10 @@ def get_current_user_profile(
""", """,
operation_id="update_current_user" operation_id="update_current_user"
) )
def update_current_user( async def update_current_user(
user_update: UserUpdate, user_update: UserUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Update current user's profile. Update current user's profile.
@@ -154,7 +154,7 @@ def update_current_user(
) )
try: try:
updated_user = user_crud.update( updated_user = await user_crud.update(
db, db,
db_obj=current_user, db_obj=current_user,
obj_in=user_update obj_in=user_update
@@ -185,10 +185,10 @@ def update_current_user(
""", """,
operation_id="get_user_by_id" operation_id="get_user_by_id"
) )
def get_user_by_id( async def get_user_by_id(
user_id: UUID, user_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Get user by ID. Get user by ID.
@@ -206,7 +206,7 @@ def get_user_by_id(
) )
# Get user # Get user
user = user_crud.get(db, id=str(user_id)) user = await user_crud.get(db, id=str(user_id))
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User with id {user_id} not found", message=f"User with id {user_id} not found",
@@ -232,11 +232,11 @@ def get_user_by_id(
""", """,
operation_id="update_user" operation_id="update_user"
) )
def update_user( async def update_user(
user_id: UUID, user_id: UUID,
user_update: UserUpdate, user_update: UserUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Update user by ID. Update user by ID.
@@ -257,7 +257,7 @@ def update_user(
) )
# Get user # Get user
user = user_crud.get(db, id=str(user_id)) user = await user_crud.get(db, id=str(user_id))
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User with id {user_id} not found", message=f"User with id {user_id} not found",
@@ -273,7 +273,7 @@ def update_user(
) )
try: try:
updated_user = user_crud.update(db, db_obj=user, obj_in=user_update) updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
logger.info(f"User {user_id} updated by {current_user.id}") logger.info(f"User {user_id} updated by {current_user.id}")
return updated_user return updated_user
except ValueError as e: except ValueError as e:
@@ -300,11 +300,11 @@ def update_user(
operation_id="change_current_user_password" operation_id="change_current_user_password"
) )
@limiter.limit("5/minute") @limiter.limit("5/minute")
def change_current_user_password( async def change_current_user_password(
request: Request, request: Request,
password_change: PasswordChange, password_change: PasswordChange,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Change current user's password. Change current user's password.
@@ -312,7 +312,7 @@ def change_current_user_password(
Requires current password for verification. Requires current password for verification.
""" """
try: try:
success = AuthService.change_password( success = await AuthService.change_password(
db=db, db=db,
user_id=current_user.id, user_id=current_user.id,
current_password=password_change.current_password, current_password=password_change.current_password,
@@ -353,10 +353,10 @@ def change_current_user_password(
""", """,
operation_id="delete_user" operation_id="delete_user"
) )
def delete_user( async def delete_user(
user_id: UUID, user_id: UUID,
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
Delete user by ID (superuser only). Delete user by ID (superuser only).
@@ -371,7 +371,7 @@ def delete_user(
) )
# Get user # Get user
user = user_crud.get(db, id=str(user_id)) user = await user_crud.get(db, id=str(user_id))
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User with id {user_id} not found", message=f"User with id {user_id} not found",
@@ -380,7 +380,7 @@ def delete_user(
try: try:
# Use soft delete instead of hard delete # Use soft delete instead of hard delete
user_crud.soft_delete(db, id=str(user_id)) await user_crud.soft_delete(db, id=str(user_id))
logger.info(f"User {user_id} soft-deleted by {current_user.id}") logger.info(f"User {user_id} soft-deleted by {current_user.id}")
return MessageResponse( return MessageResponse(
success=True, success=True,

4
backend/app/core/database_async.py Normal file → Executable file
View File

@@ -159,6 +159,10 @@ async def check_async_database_health() -> bool:
return False return False
# Alias for consistency with main.py
check_database_health = check_async_database_health
async def init_async_db() -> None: async def init_async_db() -> None:
""" """
Initialize async database tables. Initialize async database tables.

143
backend/app/crud/base_async.py Normal file → Executable file
View File

@@ -179,10 +179,25 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise raise
async def get_multi_with_total( async def get_multi_with_total(
self, db: AsyncSession, *, skip: int = 0, limit: int = 100 self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]: ) -> Tuple[List[ModelType], int]:
""" """
Get multiple records with total count for pagination. Get multiple records with total count, filtering, and sorting.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns: Returns:
Tuple of (items, total_count) Tuple of (items, total_count)
@@ -196,16 +211,35 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise ValueError("Maximum limit is 1000") raise ValueError("Maximum limit is 1000")
try: try:
# Get total count # Build base query
count_result = await db.execute( query = select(self.model)
select(func.count(self.model.id))
) # Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
query = query.where(self.model.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.where(getattr(self.model, field) == value)
# Get total count (before pagination)
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Get paginated items # Apply sorting
items_result = await db.execute( if sort_by and hasattr(self.model, sort_by):
select(self.model).offset(skip).limit(limit) sort_column = getattr(self.model, sort_by)
) if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
items_result = await db.execute(query)
items = list(items_result.scalars().all()) items = list(items_result.scalars().all())
return items, total return items, total
@@ -226,3 +260,92 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""Check if a record exists by ID.""" """Check if a record exists by ID."""
obj = await self.get(db, id=id) obj = await self.get(db, id=id)
return obj is not None return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime, timezone
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
return None
# Clear deleted_at timestamp
obj.deleted_at = None
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise

0
backend/app/init_db.py Normal file → Executable file
View File

4
backend/app/main.py Normal file → Executable file
View File

@@ -14,7 +14,7 @@ from sqlalchemy import text
from app.api.main import api_router from app.api.main import api_router
from app.core.config import settings from app.core.config import settings
from app.core.database import get_db, check_database_health from app.core.database_async import check_database_health
from app.core.exceptions import ( from app.core.exceptions import (
APIException, APIException,
api_exception_handler, api_exception_handler,
@@ -218,7 +218,7 @@ async def health_check() -> JSONResponse:
# Database health check using dedicated health check function # Database health check using dedicated health check function
try: try:
db_healthy = check_database_health() db_healthy = await check_database_health()
if db_healthy: if db_healthy:
health_status["checks"]["database"] = { health_status["checks"]["database"] = {
"status": "healthy", "status": "healthy",

29
backend/app/services/auth_service.py Normal file → Executable file
View File

@@ -3,7 +3,8 @@ import logging
from typing import Optional from typing import Optional
from uuid import UUID from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.auth import ( from app.core.auth import (
verify_password, verify_password,
@@ -28,7 +29,7 @@ class AuthService:
"""Service for handling authentication operations""" """Service for handling authentication operations"""
@staticmethod @staticmethod
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
""" """
Authenticate a user with email and password. Authenticate a user with email and password.
@@ -40,7 +41,8 @@ class AuthService:
Returns: Returns:
User if authenticated, None otherwise User if authenticated, None otherwise
""" """
user = db.query(User).filter(User.email == email).first() result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
if not user: if not user:
return None return None
@@ -54,7 +56,7 @@ class AuthService:
return user return user
@staticmethod @staticmethod
def create_user(db: Session, user_data: UserCreate) -> User: async def create_user(db: AsyncSession, user_data: UserCreate) -> User:
""" """
Create a new user. Create a new user.
@@ -66,7 +68,8 @@ class AuthService:
Created user Created user
""" """
# Check if user already exists # Check if user already exists
existing_user = db.query(User).filter(User.email == user_data.email).first() result = await db.execute(select(User).where(User.email == user_data.email))
existing_user = result.scalar_one_or_none()
if existing_user: if existing_user:
raise AuthenticationError("User with this email already exists") raise AuthenticationError("User with this email already exists")
@@ -85,8 +88,8 @@ class AuthService:
) )
db.add(user) db.add(user)
db.commit() await db.commit()
db.refresh(user) await db.refresh(user)
return user return user
@@ -124,7 +127,7 @@ class AuthService:
) )
@staticmethod @staticmethod
def refresh_tokens(db: Session, refresh_token: str) -> Token: async def refresh_tokens(db: AsyncSession, refresh_token: str) -> Token:
""" """
Generate new tokens using a refresh token. Generate new tokens using a refresh token.
@@ -150,7 +153,8 @@ class AuthService:
user_id = token_data.user_id user_id = token_data.user_id
# Get user from database # Get user from database
user = db.query(User).filter(User.id == user_id).first() result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account") raise TokenInvalidError("Invalid user or inactive account")
@@ -162,7 +166,7 @@ class AuthService:
raise raise
@staticmethod @staticmethod
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool: async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool:
""" """
Change a user's password. Change a user's password.
@@ -178,7 +182,8 @@ class AuthService:
Raises: Raises:
AuthenticationError: If current password is incorrect AuthenticationError: If current password is incorrect
""" """
user = db.query(User).filter(User.id == user_id).first() result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user: if not user:
raise AuthenticationError("User not found") raise AuthenticationError("User not found")
@@ -188,6 +193,6 @@ class AuthService:
# Update password # Update password
user.password_hash = get_password_hash(new_password) user.password_hash = get_password_hash(new_password)
db.commit() await db.commit()
return True return True

78
backend/app/services/session_cleanup.py Normal file → Executable file
View File

@@ -6,13 +6,13 @@ This service runs periodically to remove old session records from the database.
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from app.core.database import SessionLocal from app.core.database_async import AsyncSessionLocal
from app.crud.session import session as session_crud from app.crud.session_async import session_async as session_crud
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def cleanup_expired_sessions(keep_days: int = 30) -> int: async def cleanup_expired_sessions(keep_days: int = 30) -> int:
""" """
Clean up expired and inactive sessions. Clean up expired and inactive sessions.
@@ -29,52 +29,58 @@ def cleanup_expired_sessions(keep_days: int = 30) -> int:
""" """
logger.info("Starting session cleanup job...") logger.info("Starting session cleanup job...")
db = SessionLocal() async with AsyncSessionLocal() as db:
try: try:
# Use CRUD method to cleanup # Use CRUD method to cleanup
count = session_crud.cleanup_expired(db, keep_days=keep_days) count = await session_crud.cleanup_expired(db, keep_days=keep_days)
logger.info(f"Session cleanup complete: {count} sessions deleted") logger.info(f"Session cleanup complete: {count} sessions deleted")
return count return count
except Exception as e: except Exception as e:
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True) logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
return 0 return 0
finally:
db.close()
def get_session_statistics() -> dict: async def get_session_statistics() -> dict:
""" """
Get statistics about current sessions. Get statistics about current sessions.
Returns: Returns:
Dictionary with session stats Dictionary with session stats
""" """
db = SessionLocal() async with AsyncSessionLocal() as db:
try: try:
from app.models.user_session import UserSession from app.models.user_session import UserSession
from sqlalchemy import select, func
total_sessions = db.query(UserSession).count() total_result = await db.execute(select(func.count(UserSession.id)))
active_sessions = db.query(UserSession).filter(UserSession.is_active == True).count() total_sessions = total_result.scalar_one()
expired_sessions = db.query(UserSession).filter(
UserSession.expires_at < datetime.now(timezone.utc)
).count()
stats = { active_result = await db.execute(
"total": total_sessions, select(func.count(UserSession.id)).where(UserSession.is_active == True)
"active": active_sessions, )
"inactive": total_sessions - active_sessions, active_sessions = active_result.scalar_one()
"expired": expired_sessions,
}
logger.info(f"Session statistics: {stats}") expired_result = await db.execute(
select(func.count(UserSession.id)).where(
UserSession.expires_at < datetime.now(timezone.utc)
)
)
expired_sessions = expired_result.scalar_one()
return stats stats = {
"total": total_sessions,
"active": active_sessions,
"inactive": total_sessions - active_sessions,
"expired": expired_sessions,
}
except Exception as e: logger.info(f"Session statistics: {stats}")
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
return {} return stats
finally:
db.close() except Exception as e:
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
return {}