Refactor backend to adopt async patterns across services, API routes, and CRUD operations

- Migrated database sessions and operations to `AsyncSession` for full async support.
- Updated all service methods and dependencies (`get_db` to `get_async_db`) to support async logic.
- Refactored admin, user, organization, session-related CRUD methods, and routes with await syntax.
- Improved consistency and performance with async SQLAlchemy patterns.
- Enhanced logging and error handling for async context.
This commit is contained in:
Felipe Cardoso
2025-10-31 21:57:12 +01:00
parent 19ecd04a41
commit 26ff08d9f9
14 changed files with 385 additions and 239 deletions

0
backend/app/__init__.py Normal file → Executable file
View File

24
backend/app/api/dependencies/auth.py Normal file → Executable file
View File

@@ -3,18 +3,19 @@ from typing import Optional
from fastapi import Depends, HTTPException, status, Header
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.database import get_db
from app.core.database_async import get_async_db
from app.models.user import User
# OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
def get_current_user(
db: Session = Depends(get_db),
async def get_current_user(
db: AsyncSession = Depends(get_async_db),
token: str = Depends(oauth2_scheme)
) -> User:
"""
@@ -35,7 +36,11 @@ def get_current_user(
token_data = get_token_data(token)
# Get user from database
user = db.query(User).filter(User.id == token_data.user_id).first()
result = await db.execute(
select(User).where(User.id == token_data.user_id)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -133,8 +138,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
return token
def get_optional_current_user(
db: Session = Depends(get_db),
async def get_optional_current_user(
db: AsyncSession = Depends(get_async_db),
token: Optional[str] = Depends(get_optional_token)
) -> Optional[User]:
"""
@@ -153,7 +158,10 @@ def get_optional_current_user(
try:
token_data = get_token_data(token)
user = db.query(User).filter(User.id == token_data.user_id).first()
result = await db.execute(
select(User).where(User.id == token_data.user_id)
)
user = result.scalar_one_or_none()
if not user or not user.is_active:
return None
return user

26
backend/app/api/dependencies/permissions.py Normal file → Executable file
View File

@@ -10,13 +10,13 @@ These dependencies are optional and flexible:
from typing import Optional
from uuid import UUID
from fastapi import Depends, HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.core.database_async import get_async_db
from app.models.user import User
from app.models.user_organization import OrganizationRole
from app.api.dependencies.auth import get_current_user
from app.crud.organization import organization as organization_crud
from app.crud.organization_async import organization_async as organization_crud
def require_superuser(
@@ -73,11 +73,11 @@ class OrganizationPermission:
"""
self.allowed_roles = allowed_roles
def __call__(
async def __call__(
self,
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> User:
"""
Check if user has required role in the organization.
@@ -98,7 +98,7 @@ class OrganizationPermission:
return current_user
# Get user's role in organization
user_role = organization_crud.get_user_role_in_org(
user_role = await organization_crud.get_user_role_in_org(
db,
user_id=current_user.id,
organization_id=organization_id
@@ -129,10 +129,10 @@ require_org_member = OrganizationPermission([
])
def get_current_org_role(
async def get_current_org_role(
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Optional[OrganizationRole]:
"""
Get the current user's role in an organization.
@@ -142,7 +142,7 @@ def get_current_org_role(
Example:
@router.get("/organizations/{org_id}/items")
def list_items(
async def list_items(
org_id: UUID,
role: OrganizationRole = Depends(get_current_org_role)
):
@@ -153,17 +153,17 @@ def get_current_org_role(
if current_user.is_superuser:
return OrganizationRole.OWNER
return organization_crud.get_user_role_in_org(
return await organization_crud.get_user_role_in_org(
db,
user_id=current_user.id,
organization_id=organization_id
)
def require_org_membership(
async def require_org_membership(
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> User:
"""
Ensure user is a member of the organization (any role).
@@ -173,7 +173,7 @@ def require_org_membership(
if current_user.is_superuser:
return current_user
user_role = organization_crud.get_user_role_in_org(
user_role = await organization_crud.get_user_role_in_org(
db,
user_id=current_user.id,
organization_id=organization_id

138
backend/app/api/routes/admin.py Normal file → Executable file
View File

@@ -11,13 +11,13 @@ from uuid import UUID
from enum import Enum
from fastapi import APIRouter, Depends, Query, Body, status
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel, Field
from app.api.dependencies.permissions import require_superuser
from app.core.database import get_db
from app.crud.user import user as user_crud
from app.crud.organization import organization as organization_crud
from app.core.database_async import get_async_db
from app.crud.user_async import user_async as user_crud
from app.crud.organization_async import organization_async as organization_crud
from app.models.user import User
from app.models.user_organization import OrganizationRole
from app.schemas.users import UserResponse, UserCreate, UserUpdate
@@ -73,14 +73,14 @@ class BulkActionResult(BaseModel):
description="Get paginated list of all users with filtering and search (admin only)",
operation_id="admin_list_users"
)
def admin_list_users(
async def admin_list_users(
pagination: PaginationParams = Depends(),
sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
search: Optional[str] = Query(None, description="Search by email, name"),
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
List all users with comprehensive filtering and search.
@@ -96,7 +96,7 @@ def admin_list_users(
filters["is_superuser"] = is_superuser
# Get users with search
users, total = user_crud.get_multi_with_total(
users, total = await user_crud.get_multi_with_total(
db,
skip=pagination.offset,
limit=pagination.limit,
@@ -128,10 +128,10 @@ def admin_list_users(
description="Create a new user (admin only)",
operation_id="admin_create_user"
)
def admin_create_user(
async def admin_create_user(
user_in: UserCreate,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Create a new user with admin privileges.
@@ -139,7 +139,7 @@ def admin_create_user(
Allows setting is_superuser and other fields.
"""
try:
user = user_crud.create(db, obj_in=user_in)
user = await user_crud.create(db, obj_in=user_in)
logger.info(f"Admin {admin.email} created user {user.email}")
return user
except ValueError as e:
@@ -160,13 +160,13 @@ def admin_create_user(
description="Get detailed user information (admin only)",
operation_id="admin_get_user"
)
def admin_get_user(
async def admin_get_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Get detailed information about a specific user."""
user = user_crud.get(db, id=user_id)
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
detail=f"User {user_id} not found",
@@ -182,22 +182,22 @@ def admin_get_user(
description="Update user information (admin only)",
operation_id="admin_update_user"
)
def admin_update_user(
async def admin_update_user(
user_id: UUID,
user_in: UserUpdate,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Update user information with admin privileges."""
try:
user = user_crud.get(db, id=user_id)
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
detail=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
)
updated_user = user_crud.update(db, db_obj=user, obj_in=user_in)
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
return updated_user
@@ -215,14 +215,14 @@ def admin_update_user(
description="Soft delete a user (admin only)",
operation_id="admin_delete_user"
)
def admin_delete_user(
async def admin_delete_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Soft delete a user (sets deleted_at timestamp)."""
try:
user = user_crud.get(db, id=user_id)
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
detail=f"User {user_id} not found",
@@ -236,7 +236,7 @@ def admin_delete_user(
error_code=ErrorCode.OPERATION_FORBIDDEN
)
user_crud.soft_delete(db, id=user_id)
await user_crud.soft_delete(db, id=user_id)
logger.info(f"Admin {admin.email} deleted user {user.email}")
return MessageResponse(
@@ -258,21 +258,21 @@ def admin_delete_user(
description="Activate a user account (admin only)",
operation_id="admin_activate_user"
)
def admin_activate_user(
async def admin_activate_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Activate a user account."""
try:
user = user_crud.get(db, id=user_id)
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
detail=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
)
user_crud.update(db, db_obj=user, obj_in={"is_active": True})
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
logger.info(f"Admin {admin.email} activated user {user.email}")
return MessageResponse(
@@ -294,14 +294,14 @@ def admin_activate_user(
description="Deactivate a user account (admin only)",
operation_id="admin_deactivate_user"
)
def admin_deactivate_user(
async def admin_deactivate_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Deactivate a user account."""
try:
user = user_crud.get(db, id=user_id)
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
detail=f"User {user_id} not found",
@@ -315,7 +315,7 @@ def admin_deactivate_user(
error_code=ErrorCode.OPERATION_FORBIDDEN
)
user_crud.update(db, db_obj=user, obj_in={"is_active": False})
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
logger.info(f"Admin {admin.email} deactivated user {user.email}")
return MessageResponse(
@@ -337,10 +337,10 @@ def admin_deactivate_user(
description="Perform bulk actions on multiple users (admin only)",
operation_id="admin_bulk_user_action"
)
def admin_bulk_user_action(
async def admin_bulk_user_action(
bulk_action: BulkUserAction,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Perform bulk actions on multiple users.
@@ -354,7 +354,7 @@ def admin_bulk_user_action(
try:
for user_id in bulk_action.user_ids:
try:
user = user_crud.get(db, id=user_id)
user = await user_crud.get(db, id=user_id)
if not user:
failed_count += 1
failed_ids.append(user_id)
@@ -367,11 +367,11 @@ def admin_bulk_user_action(
continue
if bulk_action.action == BulkAction.ACTIVATE:
user_crud.update(db, db_obj=user, obj_in={"is_active": True})
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
elif bulk_action.action == BulkAction.DEACTIVATE:
user_crud.update(db, db_obj=user, obj_in={"is_active": False})
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
elif bulk_action.action == BulkAction.DELETE:
user_crud.soft_delete(db, id=user_id)
await user_crud.soft_delete(db, id=user_id)
affected_count += 1
@@ -407,16 +407,16 @@ def admin_bulk_user_action(
description="Get paginated list of all organizations (admin only)",
operation_id="admin_list_organizations"
)
def admin_list_organizations(
async def admin_list_organizations(
pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
search: Optional[str] = Query(None, description="Search by name, slug, description"),
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""List all organizations with filtering and search."""
try:
orgs, total = organization_crud.get_multi_with_filters(
orgs, total = await organization_crud.get_multi_with_filters(
db,
skip=pagination.offset,
limit=pagination.limit,
@@ -438,7 +438,7 @@ def admin_list_organizations(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
}
orgs_with_count.append(OrganizationResponse(**org_dict))
@@ -464,14 +464,14 @@ def admin_list_organizations(
description="Create a new organization (admin only)",
operation_id="admin_create_organization"
)
def admin_create_organization(
async def admin_create_organization(
org_in: OrganizationCreate,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Create a new organization."""
try:
org = organization_crud.create(db, obj_in=org_in)
org = await organization_crud.create(db, obj_in=org_in)
logger.info(f"Admin {admin.email} created organization {org.name}")
# Add member count
@@ -506,13 +506,13 @@ def admin_create_organization(
description="Get detailed organization information (admin only)",
operation_id="admin_get_organization"
)
def admin_get_organization(
async def admin_get_organization(
org_id: UUID,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Get detailed information about a specific organization."""
org = organization_crud.get(db, id=org_id)
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
detail=f"Organization {org_id} not found",
@@ -528,7 +528,7 @@ def admin_get_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
}
return OrganizationResponse(**org_dict)
@@ -540,22 +540,22 @@ def admin_get_organization(
description="Update organization information (admin only)",
operation_id="admin_update_organization"
)
def admin_update_organization(
async def admin_update_organization(
org_id: UUID,
org_in: OrganizationUpdate,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Update organization information."""
try:
org = organization_crud.get(db, id=org_id)
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
)
updated_org = organization_crud.update(db, db_obj=org, obj_in=org_in)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
org_dict = {
@@ -567,7 +567,7 @@ def admin_update_organization(
"settings": updated_org.settings,
"created_at": updated_org.created_at,
"updated_at": updated_org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=updated_org.id)
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
}
return OrganizationResponse(**org_dict)
@@ -585,21 +585,21 @@ def admin_update_organization(
description="Delete an organization (admin only)",
operation_id="admin_delete_organization"
)
def admin_delete_organization(
async def admin_delete_organization(
org_id: UUID,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Delete an organization and all its relationships."""
try:
org = organization_crud.get(db, id=org_id)
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
)
organization_crud.remove(db, id=org_id)
await organization_crud.remove(db, id=org_id)
logger.info(f"Admin {admin.email} deleted organization {org.name}")
return MessageResponse(
@@ -621,23 +621,23 @@ def admin_delete_organization(
description="Get all members of an organization (admin only)",
operation_id="admin_list_organization_members"
)
def admin_list_organization_members(
async def admin_list_organization_members(
org_id: UUID,
pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(True, description="Filter by active status"),
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""List all members of an organization."""
try:
org = organization_crud.get(db, id=org_id)
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
)
members, total = organization_crud.get_organization_members(
members, total = await organization_crud.get_organization_members(
db,
organization_id=org_id,
skip=pagination.offset,
@@ -677,29 +677,29 @@ class AddMemberRequest(BaseModel):
description="Add a user to an organization (admin only)",
operation_id="admin_add_organization_member"
)
def admin_add_organization_member(
async def admin_add_organization_member(
org_id: UUID,
request: AddMemberRequest,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Add a user to an organization."""
try:
org = organization_crud.get(db, id=org_id)
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
)
user = user_crud.get(db, id=request.user_id)
user = await user_crud.get(db, id=request.user_id)
if not user:
raise NotFoundError(
detail=f"User {request.user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
)
organization_crud.add_user(
await organization_crud.add_user(
db,
organization_id=org_id,
user_id=request.user_id,
@@ -733,29 +733,29 @@ def admin_add_organization_member(
description="Remove a user from an organization (admin only)",
operation_id="admin_remove_organization_member"
)
def admin_remove_organization_member(
async def admin_remove_organization_member(
org_id: UUID,
user_id: UUID,
admin: User = Depends(require_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""Remove a user from an organization."""
try:
org = organization_crud.get(db, id=org_id)
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
detail=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
)
user = user_crud.get(db, id=user_id)
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
detail=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
)
success = organization_crud.remove_user(
success = await organization_crud.remove_user(
db,
organization_id=org_id,
user_id=user_id

62
backend/app/api/routes/auth.py Normal file → Executable file
View File

@@ -8,11 +8,11 @@ from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
from app.core.database import get_db
from app.core.database_async import get_async_db
from app.models.user import User
from app.schemas.users import (
UserCreate,
@@ -29,8 +29,8 @@ from app.services.auth_service import AuthService, AuthenticationError
from app.services.email_service import email_service
from app.utils.security import create_password_reset_token, verify_password_reset_token
from app.utils.device import extract_device_info
from app.crud.user import user as user_crud
from app.crud.session import session as session_crud
from app.crud.user_async import user_async as user_crud
from app.crud.session_async import session_async as session_crud
from app.core.auth import get_password_hash
router = APIRouter()
@@ -49,7 +49,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1
async def register_user(
request: Request,
user_data: UserCreate,
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Register a new user.
@@ -58,7 +58,7 @@ async def register_user(
The created user information.
"""
try:
user = AuthService.create_user(db, user_data)
user = await AuthService.create_user(db, user_data)
return user
except AuthenticationError as e:
logger.warning(f"Registration failed: {str(e)}")
@@ -79,7 +79,7 @@ async def register_user(
async def login(
request: Request,
login_data: LoginRequest,
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Login with username and password.
@@ -91,7 +91,7 @@ async def login(
"""
try:
# Attempt to authenticate the user
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
user = await AuthService.authenticate_user(db, login_data.email, login_data.password)
# Explicitly check for None result and raise correct exception
if user is None:
@@ -126,7 +126,7 @@ async def login(
location_country=device_info.location_country,
)
session_crud.create_session(db, obj_in=session_data)
await session_crud.create_session(db, obj_in=session_data)
logger.info(
f"User login successful: {user.email} from {device_info.device_name} "
@@ -163,7 +163,7 @@ async def login(
async def login_oauth(
request: Request,
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
OAuth2-compatible login endpoint, used by the OpenAPI UI.
@@ -174,7 +174,7 @@ async def login_oauth(
Access and refresh tokens.
"""
try:
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
if user is None:
raise HTTPException(
@@ -207,7 +207,7 @@ async def login_oauth(
location_country=device_info.location_country,
)
session_crud.create_session(db, obj_in=session_data)
await session_crud.create_session(db, obj_in=session_data)
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
except Exception as session_err:
@@ -241,7 +241,7 @@ async def login_oauth(
async def refresh_token(
request: Request,
refresh_data: RefreshTokenRequest,
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Refresh access token using a refresh token.
@@ -256,7 +256,7 @@ async def refresh_token(
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
# Check if session exists and is active
session = session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
if not session:
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
@@ -267,14 +267,14 @@ async def refresh_token(
)
# Generate new tokens
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
tokens = await AuthService.refresh_tokens(db, refresh_data.refresh_token)
# Decode new refresh token to get new JTI
new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
# Update session with new refresh token JTI and expiration
try:
session_crud.update_refresh_token(
await session_crud.update_refresh_token(
db,
session=session,
new_jti=new_refresh_payload.jti,
@@ -344,7 +344,7 @@ async def get_current_user_info(
async def request_password_reset(
request: Request,
reset_request: PasswordResetRequest,
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Request a password reset.
@@ -354,7 +354,7 @@ async def request_password_reset(
"""
try:
# Look up user by email
user = user_crud.get_by_email(db, email=reset_request.email)
user = await user_crud.get_by_email(db, email=reset_request.email)
# Only send email if user exists and is active
if user and user.is_active:
@@ -399,10 +399,10 @@ async def request_password_reset(
operation_id="confirm_password_reset"
)
@limiter.limit("5/minute")
def confirm_password_reset(
async def confirm_password_reset(
request: Request,
reset_confirm: PasswordResetConfirm,
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Confirm password reset with token.
@@ -420,7 +420,7 @@ def confirm_password_reset(
)
# Look up user
user = user_crud.get_by_email(db, email=email)
user = await user_crud.get_by_email(db, email=email)
if not user:
raise HTTPException(
@@ -437,7 +437,7 @@ def confirm_password_reset(
# Update password
user.password_hash = get_password_hash(reset_confirm.new_password)
db.add(user)
db.commit()
await db.commit()
logger.info(f"Password reset successful for {user.email}")
@@ -450,7 +450,7 @@ def confirm_password_reset(
raise
except Exception as e:
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
db.rollback()
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while resetting your password"
@@ -474,11 +474,11 @@ def confirm_password_reset(
operation_id="logout"
)
@limiter.limit("10/minute")
def logout(
async def logout(
request: Request,
logout_request: LogoutRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Logout from current device by deactivating the session.
@@ -505,7 +505,7 @@ def logout(
)
# Find the session by JTI
session = session_crud.get_by_jti(db, jti=refresh_payload.jti)
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
if session:
# Verify session belongs to current user (security check)
@@ -520,7 +520,7 @@ def logout(
)
# Deactivate the session
session_crud.deactivate(db, session_id=str(session.id))
await session_crud.deactivate(db, session_id=str(session.id))
logger.info(
f"User {current_user.id} logged out from {session.device_name} "
@@ -563,10 +563,10 @@ def logout(
operation_id="logout_all"
)
@limiter.limit("5/minute")
def logout_all(
async def logout_all(
request: Request,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Logout from all devices by deactivating all user sessions.
@@ -580,7 +580,7 @@ def logout_all(
"""
try:
# Deactivate all sessions for this user
count = session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
@@ -591,7 +591,7 @@ def logout_all(
except Exception as e:
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
db.rollback()
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while logging out"

38
backend/app/api/routes/organizations.py Normal file → Executable file
View File

@@ -9,12 +9,12 @@ from typing import Any, List, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership, get_current_org_role
from app.core.database import get_db
from app.crud.organization import organization as organization_crud
from app.core.database_async import get_async_db
from app.crud.organization_async import organization_async as organization_crud
from app.models.user import User
from app.models.user_organization import OrganizationRole
from app.schemas.organizations import (
@@ -42,10 +42,10 @@ router = APIRouter()
description="Get all organizations the current user belongs to",
operation_id="get_my_organizations"
)
def get_my_organizations(
async def get_my_organizations(
is_active: bool = Query(True, description="Filter by active membership"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Get all organizations the current user belongs to.
@@ -53,7 +53,7 @@ def get_my_organizations(
Returns organizations with member count for each.
"""
try:
orgs = organization_crud.get_user_organizations(
orgs = await organization_crud.get_user_organizations(
db,
user_id=current_user.id,
is_active=is_active
@@ -77,7 +77,7 @@ def get_my_organizations(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
}
orgs_with_data.append(OrganizationResponse(**org_dict))
@@ -95,10 +95,10 @@ def get_my_organizations(
description="Get details of an organization the user belongs to",
operation_id="get_organization"
)
def get_organization(
async def get_organization(
organization_id: UUID,
current_user: User = Depends(require_org_membership),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Get details of a specific organization.
@@ -106,7 +106,7 @@ def get_organization(
User must be a member of the organization.
"""
try:
org = organization_crud.get(db, id=organization_id)
org = await organization_crud.get(db, id=organization_id)
if not org:
raise NotFoundError(
detail=f"Organization {organization_id} not found",
@@ -122,7 +122,7 @@ def get_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
}
return OrganizationResponse(**org_dict)
@@ -140,12 +140,12 @@ def get_organization(
description="Get all members of an organization (members can view)",
operation_id="get_organization_members"
)
def get_organization_members(
async def get_organization_members(
organization_id: UUID,
pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"),
current_user: User = Depends(require_org_membership),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Get all members of an organization.
@@ -153,7 +153,7 @@ def get_organization_members(
User must be a member of the organization to view members.
"""
try:
members, total = organization_crud.get_organization_members(
members, total = await organization_crud.get_organization_members(
db,
organization_id=organization_id,
skip=pagination.offset,
@@ -184,11 +184,11 @@ def get_organization_members(
description="Update organization details (admin/owner only)",
operation_id="update_organization"
)
def update_organization(
async def update_organization(
organization_id: UUID,
org_in: OrganizationUpdate,
current_user: User = Depends(require_org_admin),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Update organization details.
@@ -196,14 +196,14 @@ def update_organization(
Requires owner or admin role in the organization.
"""
try:
org = organization_crud.get(db, id=organization_id)
org = await organization_crud.get(db, id=organization_id)
if not org:
raise NotFoundError(
detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND
)
updated_org = organization_crud.update(db, db_obj=org, obj_in=org_in)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info(f"User {current_user.email} updated organization {updated_org.name}")
org_dict = {
@@ -215,7 +215,7 @@ def update_organization(
"settings": updated_org.settings,
"created_at": updated_org.created_at,
"updated_at": updated_org.updated_at,
"member_count": organization_crud.get_member_count(db, organization_id=updated_org.id)
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
}
return OrganizationResponse(**org_dict)

32
backend/app/api/routes/sessions.py Normal file → Executable file
View File

@@ -10,15 +10,15 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.database import get_db
from app.core.database_async import get_async_db
from app.core.auth import decode_token
from app.models.user import User
from app.schemas.sessions import SessionResponse, SessionListResponse
from app.schemas.common import MessageResponse
from app.crud.session import session as session_crud
from app.crud.session_async import session_async as session_crud
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
router = APIRouter()
@@ -42,10 +42,10 @@ limiter = Limiter(key_func=get_remote_address)
operation_id="list_my_sessions"
)
@limiter.limit("30/minute")
def list_my_sessions(
async def list_my_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
List all active sessions for the current user.
@@ -59,7 +59,7 @@ def list_my_sessions(
"""
try:
# Get all active sessions for user
sessions = session_crud.get_user_sessions(
sessions = await session_crud.get_user_sessions(
db,
user_id=str(current_user.id),
active_only=True
@@ -125,11 +125,11 @@ def list_my_sessions(
operation_id="revoke_session"
)
@limiter.limit("10/minute")
def revoke_session(
async def revoke_session(
request: Request,
session_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Revoke a specific session by ID.
@@ -144,7 +144,7 @@ def revoke_session(
"""
try:
# Get the session
session = session_crud.get(db, id=str(session_id))
session = await session_crud.get(db, id=str(session_id))
if not session:
raise NotFoundError(
@@ -164,7 +164,7 @@ def revoke_session(
)
# Deactivate the session
session_crud.deactivate(db, session_id=str(session_id))
await session_crud.deactivate(db, session_id=str(session_id))
logger.info(
f"User {current_user.id} revoked session {session_id} "
@@ -201,10 +201,10 @@ def revoke_session(
operation_id="cleanup_expired_sessions"
)
@limiter.limit("5/minute")
def cleanup_expired_sessions(
async def cleanup_expired_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Cleanup expired sessions for the current user.
@@ -220,7 +220,7 @@ def cleanup_expired_sessions(
from datetime import datetime, timezone
# Get all sessions for user
all_sessions = session_crud.get_user_sessions(
all_sessions = await session_crud.get_user_sessions(
db,
user_id=str(current_user.id),
active_only=False
@@ -230,10 +230,10 @@ def cleanup_expired_sessions(
deleted_count = 0
for s in all_sessions:
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
db.delete(s)
await db.delete(s)
deleted_count += 1
db.commit()
await db.commit()
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
@@ -244,7 +244,7 @@ def cleanup_expired_sessions(
except Exception as e:
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
db.rollback()
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cleanup sessions"

46
backend/app/api/routes/users.py Normal file → Executable file
View File

@@ -6,13 +6,13 @@ from typing import Any, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query, status, Request
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from slowapi import Limiter
from slowapi.util import get_remote_address
from app.api.dependencies.auth import get_current_user, get_current_superuser
from app.core.database import get_db
from app.crud.user import user as user_crud
from app.core.database_async import get_async_db
from app.crud.user_async import user_async as user_crud
from app.models.user import User
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
from app.schemas.common import (
@@ -52,13 +52,13 @@ limiter = Limiter(key_func=get_remote_address)
""",
operation_id="list_users"
)
def list_users(
async def list_users(
pagination: PaginationParams = Depends(),
sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
current_user: User = Depends(get_current_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
List all users with pagination, filtering, and sorting.
@@ -74,7 +74,7 @@ def list_users(
filters["is_superuser"] = is_superuser
# Get paginated users with total count
users, total = user_crud.get_multi_with_total(
users, total = await user_crud.get_multi_with_total(
db,
skip=pagination.offset,
limit=pagination.limit,
@@ -135,10 +135,10 @@ def get_current_user_profile(
""",
operation_id="update_current_user"
)
def update_current_user(
async def update_current_user(
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Update current user's profile.
@@ -154,7 +154,7 @@ def update_current_user(
)
try:
updated_user = user_crud.update(
updated_user = await user_crud.update(
db,
db_obj=current_user,
obj_in=user_update
@@ -185,10 +185,10 @@ def update_current_user(
""",
operation_id="get_user_by_id"
)
def get_user_by_id(
async def get_user_by_id(
user_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Get user by ID.
@@ -206,7 +206,7 @@ def get_user_by_id(
)
# Get user
user = user_crud.get(db, id=str(user_id))
user = await user_crud.get(db, id=str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
@@ -232,11 +232,11 @@ def get_user_by_id(
""",
operation_id="update_user"
)
def update_user(
async def update_user(
user_id: UUID,
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Update user by ID.
@@ -257,7 +257,7 @@ def update_user(
)
# Get user
user = user_crud.get(db, id=str(user_id))
user = await user_crud.get(db, id=str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
@@ -273,7 +273,7 @@ def update_user(
)
try:
updated_user = user_crud.update(db, db_obj=user, obj_in=user_update)
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
logger.info(f"User {user_id} updated by {current_user.id}")
return updated_user
except ValueError as e:
@@ -300,11 +300,11 @@ def update_user(
operation_id="change_current_user_password"
)
@limiter.limit("5/minute")
def change_current_user_password(
async def change_current_user_password(
request: Request,
password_change: PasswordChange,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Change current user's password.
@@ -312,7 +312,7 @@ def change_current_user_password(
Requires current password for verification.
"""
try:
success = AuthService.change_password(
success = await AuthService.change_password(
db=db,
user_id=current_user.id,
current_password=password_change.current_password,
@@ -353,10 +353,10 @@ def change_current_user_password(
""",
operation_id="delete_user"
)
def delete_user(
async def delete_user(
user_id: UUID,
current_user: User = Depends(get_current_superuser),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
Delete user by ID (superuser only).
@@ -371,7 +371,7 @@ def delete_user(
)
# Get user
user = user_crud.get(db, id=str(user_id))
user = await user_crud.get(db, id=str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
@@ -380,7 +380,7 @@ def delete_user(
try:
# Use soft delete instead of hard delete
user_crud.soft_delete(db, id=str(user_id))
await user_crud.soft_delete(db, id=str(user_id))
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
return MessageResponse(
success=True,

4
backend/app/core/database_async.py Normal file → Executable file
View File

@@ -159,6 +159,10 @@ async def check_async_database_health() -> bool:
return False
# Alias for consistency with main.py
check_database_health = check_async_database_health
async def init_async_db() -> None:
"""
Initialize async database tables.

143
backend/app/crud/base_async.py Normal file → Executable file
View File

@@ -179,10 +179,25 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise
async def get_multi_with_total(
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]:
"""
Get multiple records with total count for pagination.
Get multiple records with total count, filtering, and sorting.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns:
Tuple of (items, total_count)
@@ -196,16 +211,35 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise ValueError("Maximum limit is 1000")
try:
# Get total count
count_result = await db.execute(
select(func.count(self.model.id))
)
# Build base query
query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
query = query.where(self.model.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.where(getattr(self.model, field) == value)
# Get total count (before pagination)
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Get paginated items
items_result = await db.execute(
select(self.model).offset(skip).limit(limit)
)
# Apply sorting
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
items_result = await db.execute(query)
items = list(items_result.scalars().all())
return items, total
@@ -226,3 +260,92 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""Check if a record exists by ID."""
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime, timezone
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
return None
# Clear deleted_at timestamp
obj.deleted_at = None
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise

0
backend/app/init_db.py Normal file → Executable file
View File

4
backend/app/main.py Normal file → Executable file
View File

@@ -14,7 +14,7 @@ from sqlalchemy import text
from app.api.main import api_router
from app.core.config import settings
from app.core.database import get_db, check_database_health
from app.core.database_async import check_database_health
from app.core.exceptions import (
APIException,
api_exception_handler,
@@ -218,7 +218,7 @@ async def health_check() -> JSONResponse:
# Database health check using dedicated health check function
try:
db_healthy = check_database_health()
db_healthy = await check_database_health()
if db_healthy:
health_status["checks"]["database"] = {
"status": "healthy",

29
backend/app/services/auth_service.py Normal file → Executable file
View File

@@ -3,7 +3,8 @@ import logging
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.auth import (
verify_password,
@@ -28,7 +29,7 @@ class AuthService:
"""Service for handling authentication operations"""
@staticmethod
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
"""
Authenticate a user with email and password.
@@ -40,7 +41,8 @@ class AuthService:
Returns:
User if authenticated, None otherwise
"""
user = db.query(User).filter(User.email == email).first()
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
if not user:
return None
@@ -54,7 +56,7 @@ class AuthService:
return user
@staticmethod
def create_user(db: Session, user_data: UserCreate) -> User:
async def create_user(db: AsyncSession, user_data: UserCreate) -> User:
"""
Create a new user.
@@ -66,7 +68,8 @@ class AuthService:
Created user
"""
# Check if user already exists
existing_user = db.query(User).filter(User.email == user_data.email).first()
result = await db.execute(select(User).where(User.email == user_data.email))
existing_user = result.scalar_one_or_none()
if existing_user:
raise AuthenticationError("User with this email already exists")
@@ -85,8 +88,8 @@ class AuthService:
)
db.add(user)
db.commit()
db.refresh(user)
await db.commit()
await db.refresh(user)
return user
@@ -124,7 +127,7 @@ class AuthService:
)
@staticmethod
def refresh_tokens(db: Session, refresh_token: str) -> Token:
async def refresh_tokens(db: AsyncSession, refresh_token: str) -> Token:
"""
Generate new tokens using a refresh token.
@@ -150,7 +153,8 @@ class AuthService:
user_id = token_data.user_id
# Get user from database
user = db.query(User).filter(User.id == user_id).first()
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account")
@@ -162,7 +166,7 @@ class AuthService:
raise
@staticmethod
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool:
"""
Change a user's password.
@@ -178,7 +182,8 @@ class AuthService:
Raises:
AuthenticationError: If current password is incorrect
"""
user = db.query(User).filter(User.id == user_id).first()
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise AuthenticationError("User not found")
@@ -188,6 +193,6 @@ class AuthService:
# Update password
user.password_hash = get_password_hash(new_password)
db.commit()
await db.commit()
return True

78
backend/app/services/session_cleanup.py Normal file → Executable file
View File

@@ -6,13 +6,13 @@ This service runs periodically to remove old session records from the database.
import logging
from datetime import datetime, timezone
from app.core.database import SessionLocal
from app.crud.session import session as session_crud
from app.core.database_async import AsyncSessionLocal
from app.crud.session_async import session_async as session_crud
logger = logging.getLogger(__name__)
def cleanup_expired_sessions(keep_days: int = 30) -> int:
async def cleanup_expired_sessions(keep_days: int = 30) -> int:
"""
Clean up expired and inactive sessions.
@@ -29,52 +29,58 @@ def cleanup_expired_sessions(keep_days: int = 30) -> int:
"""
logger.info("Starting session cleanup job...")
db = SessionLocal()
try:
# Use CRUD method to cleanup
count = session_crud.cleanup_expired(db, keep_days=keep_days)
async with AsyncSessionLocal() as db:
try:
# Use CRUD method to cleanup
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
logger.info(f"Session cleanup complete: {count} sessions deleted")
logger.info(f"Session cleanup complete: {count} sessions deleted")
return count
return count
except Exception as e:
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
return 0
finally:
db.close()
except Exception as e:
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
return 0
def get_session_statistics() -> dict:
async def get_session_statistics() -> dict:
"""
Get statistics about current sessions.
Returns:
Dictionary with session stats
"""
db = SessionLocal()
try:
from app.models.user_session import UserSession
async with AsyncSessionLocal() as db:
try:
from app.models.user_session import UserSession
from sqlalchemy import select, func
total_sessions = db.query(UserSession).count()
active_sessions = db.query(UserSession).filter(UserSession.is_active == True).count()
expired_sessions = db.query(UserSession).filter(
UserSession.expires_at < datetime.now(timezone.utc)
).count()
total_result = await db.execute(select(func.count(UserSession.id)))
total_sessions = total_result.scalar_one()
stats = {
"total": total_sessions,
"active": active_sessions,
"inactive": total_sessions - active_sessions,
"expired": expired_sessions,
}
active_result = await db.execute(
select(func.count(UserSession.id)).where(UserSession.is_active == True)
)
active_sessions = active_result.scalar_one()
logger.info(f"Session statistics: {stats}")
expired_result = await db.execute(
select(func.count(UserSession.id)).where(
UserSession.expires_at < datetime.now(timezone.utc)
)
)
expired_sessions = expired_result.scalar_one()
return stats
stats = {
"total": total_sessions,
"active": active_sessions,
"inactive": total_sessions - active_sessions,
"expired": expired_sessions,
}
except Exception as e:
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
return {}
finally:
db.close()
logger.info(f"Session statistics: {stats}")
return stats
except Exception as e:
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
return {}