forked from cardosofelipe/fast-next-template
Refactor backend to adopt async patterns across services, API routes, and CRUD operations
- Migrated database sessions and operations to `AsyncSession` for full async support. - Updated all service methods and dependencies (`get_db` to `get_async_db`) to support async logic. - Refactored admin, user, organization, session-related CRUD methods, and routes with await syntax. - Improved consistency and performance with async SQLAlchemy patterns. - Enhanced logging and error handling for async context.
This commit is contained in:
0
backend/app/__init__.py
Normal file → Executable file
0
backend/app/__init__.py
Normal file → Executable file
24
backend/app/api/dependencies/auth.py
Normal file → Executable file
24
backend/app/api/dependencies/auth.py
Normal file → Executable file
@@ -3,18 +3,19 @@ from typing import Optional
|
|||||||
from fastapi import Depends, HTTPException, status, Header
|
from fastapi import Depends, HTTPException, status, Header
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from fastapi.security.utils import get_authorization_scheme_param
|
from fastapi.security.utils import get_authorization_scheme_param
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||||
from app.core.database import get_db
|
from app.core.database_async import get_async_db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
# OAuth2 configuration
|
# OAuth2 configuration
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(
|
async def get_current_user(
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_async_db),
|
||||||
token: str = Depends(oauth2_scheme)
|
token: str = Depends(oauth2_scheme)
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
@@ -35,7 +36,11 @@ def get_current_user(
|
|||||||
token_data = get_token_data(token)
|
token_data = get_token_data(token)
|
||||||
|
|
||||||
# Get user from database
|
# Get user from database
|
||||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
result = await db.execute(
|
||||||
|
select(User).where(User.id == token_data.user_id)
|
||||||
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
@@ -133,8 +138,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
|
|||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
def get_optional_current_user(
|
async def get_optional_current_user(
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_async_db),
|
||||||
token: Optional[str] = Depends(get_optional_token)
|
token: Optional[str] = Depends(get_optional_token)
|
||||||
) -> Optional[User]:
|
) -> Optional[User]:
|
||||||
"""
|
"""
|
||||||
@@ -153,7 +158,10 @@ def get_optional_current_user(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
token_data = get_token_data(token)
|
token_data = get_token_data(token)
|
||||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
result = await db.execute(
|
||||||
|
select(User).where(User.id == token_data.user_id)
|
||||||
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
return None
|
return None
|
||||||
return user
|
return user
|
||||||
|
|||||||
26
backend/app/api/dependencies/permissions.py
Normal file → Executable file
26
backend/app/api/dependencies/permissions.py
Normal file → Executable file
@@ -10,13 +10,13 @@ These dependencies are optional and flexible:
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database_async import get_async_db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole
|
from app.models.user_organization import OrganizationRole
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.crud.organization import organization as organization_crud
|
from app.crud.organization_async import organization_async as organization_crud
|
||||||
|
|
||||||
|
|
||||||
def require_superuser(
|
def require_superuser(
|
||||||
@@ -73,11 +73,11 @@ class OrganizationPermission:
|
|||||||
"""
|
"""
|
||||||
self.allowed_roles = allowed_roles
|
self.allowed_roles = allowed_roles
|
||||||
|
|
||||||
def __call__(
|
async def __call__(
|
||||||
self,
|
self,
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Check if user has required role in the organization.
|
Check if user has required role in the organization.
|
||||||
@@ -98,7 +98,7 @@ class OrganizationPermission:
|
|||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
# Get user's role in organization
|
# Get user's role in organization
|
||||||
user_role = organization_crud.get_user_role_in_org(
|
user_role = await organization_crud.get_user_role_in_org(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
organization_id=organization_id
|
organization_id=organization_id
|
||||||
@@ -129,10 +129,10 @@ require_org_member = OrganizationPermission([
|
|||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
def get_current_org_role(
|
async def get_current_org_role(
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Optional[OrganizationRole]:
|
) -> Optional[OrganizationRole]:
|
||||||
"""
|
"""
|
||||||
Get the current user's role in an organization.
|
Get the current user's role in an organization.
|
||||||
@@ -142,7 +142,7 @@ def get_current_org_role(
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
@router.get("/organizations/{org_id}/items")
|
@router.get("/organizations/{org_id}/items")
|
||||||
def list_items(
|
async def list_items(
|
||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
role: OrganizationRole = Depends(get_current_org_role)
|
role: OrganizationRole = Depends(get_current_org_role)
|
||||||
):
|
):
|
||||||
@@ -153,17 +153,17 @@ def get_current_org_role(
|
|||||||
if current_user.is_superuser:
|
if current_user.is_superuser:
|
||||||
return OrganizationRole.OWNER
|
return OrganizationRole.OWNER
|
||||||
|
|
||||||
return organization_crud.get_user_role_in_org(
|
return await organization_crud.get_user_role_in_org(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
organization_id=organization_id
|
organization_id=organization_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def require_org_membership(
|
async def require_org_membership(
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Ensure user is a member of the organization (any role).
|
Ensure user is a member of the organization (any role).
|
||||||
@@ -173,7 +173,7 @@ def require_org_membership(
|
|||||||
if current_user.is_superuser:
|
if current_user.is_superuser:
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
user_role = organization_crud.get_user_role_in_org(
|
user_role = await organization_crud.get_user_role_in_org(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
organization_id=organization_id
|
organization_id=organization_id
|
||||||
|
|||||||
138
backend/app/api/routes/admin.py
Normal file → Executable file
138
backend/app/api/routes/admin.py
Normal file → Executable file
@@ -11,13 +11,13 @@ from uuid import UUID
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, Body, status
|
from fastapi import APIRouter, Depends, Query, Body, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.api.dependencies.permissions import require_superuser
|
from app.api.dependencies.permissions import require_superuser
|
||||||
from app.core.database import get_db
|
from app.core.database_async import get_async_db
|
||||||
from app.crud.user import user as user_crud
|
from app.crud.user_async import user_async as user_crud
|
||||||
from app.crud.organization import organization as organization_crud
|
from app.crud.organization_async import organization_async as organization_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole
|
from app.models.user_organization import OrganizationRole
|
||||||
from app.schemas.users import UserResponse, UserCreate, UserUpdate
|
from app.schemas.users import UserResponse, UserCreate, UserUpdate
|
||||||
@@ -73,14 +73,14 @@ class BulkActionResult(BaseModel):
|
|||||||
description="Get paginated list of all users with filtering and search (admin only)",
|
description="Get paginated list of all users with filtering and search (admin only)",
|
||||||
operation_id="admin_list_users"
|
operation_id="admin_list_users"
|
||||||
)
|
)
|
||||||
def admin_list_users(
|
async def admin_list_users(
|
||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
sort: SortParams = Depends(),
|
sort: SortParams = Depends(),
|
||||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||||
search: Optional[str] = Query(None, description="Search by email, name"),
|
search: Optional[str] = Query(None, description="Search by email, name"),
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
List all users with comprehensive filtering and search.
|
List all users with comprehensive filtering and search.
|
||||||
@@ -96,7 +96,7 @@ def admin_list_users(
|
|||||||
filters["is_superuser"] = is_superuser
|
filters["is_superuser"] = is_superuser
|
||||||
|
|
||||||
# Get users with search
|
# Get users with search
|
||||||
users, total = user_crud.get_multi_with_total(
|
users, total = await user_crud.get_multi_with_total(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -128,10 +128,10 @@ def admin_list_users(
|
|||||||
description="Create a new user (admin only)",
|
description="Create a new user (admin only)",
|
||||||
operation_id="admin_create_user"
|
operation_id="admin_create_user"
|
||||||
)
|
)
|
||||||
def admin_create_user(
|
async def admin_create_user(
|
||||||
user_in: UserCreate,
|
user_in: UserCreate,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Create a new user with admin privileges.
|
Create a new user with admin privileges.
|
||||||
@@ -139,7 +139,7 @@ def admin_create_user(
|
|||||||
Allows setting is_superuser and other fields.
|
Allows setting is_superuser and other fields.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
user = user_crud.create(db, obj_in=user_in)
|
user = await user_crud.create(db, obj_in=user_in)
|
||||||
logger.info(f"Admin {admin.email} created user {user.email}")
|
logger.info(f"Admin {admin.email} created user {user.email}")
|
||||||
return user
|
return user
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -160,13 +160,13 @@ def admin_create_user(
|
|||||||
description="Get detailed user information (admin only)",
|
description="Get detailed user information (admin only)",
|
||||||
operation_id="admin_get_user"
|
operation_id="admin_get_user"
|
||||||
)
|
)
|
||||||
def admin_get_user(
|
async def admin_get_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Get detailed information about a specific user."""
|
"""Get detailed information about a specific user."""
|
||||||
user = user_crud.get(db, id=user_id)
|
user = await user_crud.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"User {user_id} not found",
|
detail=f"User {user_id} not found",
|
||||||
@@ -182,22 +182,22 @@ def admin_get_user(
|
|||||||
description="Update user information (admin only)",
|
description="Update user information (admin only)",
|
||||||
operation_id="admin_update_user"
|
operation_id="admin_update_user"
|
||||||
)
|
)
|
||||||
def admin_update_user(
|
async def admin_update_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
user_in: UserUpdate,
|
user_in: UserUpdate,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Update user information with admin privileges."""
|
"""Update user information with admin privileges."""
|
||||||
try:
|
try:
|
||||||
user = user_crud.get(db, id=user_id)
|
user = await user_crud.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"User {user_id} not found",
|
detail=f"User {user_id} not found",
|
||||||
error_code=ErrorCode.USER_NOT_FOUND
|
error_code=ErrorCode.USER_NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_user = user_crud.update(db, db_obj=user, obj_in=user_in)
|
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
|
||||||
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
|
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
|
||||||
return updated_user
|
return updated_user
|
||||||
|
|
||||||
@@ -215,14 +215,14 @@ def admin_update_user(
|
|||||||
description="Soft delete a user (admin only)",
|
description="Soft delete a user (admin only)",
|
||||||
operation_id="admin_delete_user"
|
operation_id="admin_delete_user"
|
||||||
)
|
)
|
||||||
def admin_delete_user(
|
async def admin_delete_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Soft delete a user (sets deleted_at timestamp)."""
|
"""Soft delete a user (sets deleted_at timestamp)."""
|
||||||
try:
|
try:
|
||||||
user = user_crud.get(db, id=user_id)
|
user = await user_crud.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"User {user_id} not found",
|
detail=f"User {user_id} not found",
|
||||||
@@ -236,7 +236,7 @@ def admin_delete_user(
|
|||||||
error_code=ErrorCode.OPERATION_FORBIDDEN
|
error_code=ErrorCode.OPERATION_FORBIDDEN
|
||||||
)
|
)
|
||||||
|
|
||||||
user_crud.soft_delete(db, id=user_id)
|
await user_crud.soft_delete(db, id=user_id)
|
||||||
logger.info(f"Admin {admin.email} deleted user {user.email}")
|
logger.info(f"Admin {admin.email} deleted user {user.email}")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
@@ -258,21 +258,21 @@ def admin_delete_user(
|
|||||||
description="Activate a user account (admin only)",
|
description="Activate a user account (admin only)",
|
||||||
operation_id="admin_activate_user"
|
operation_id="admin_activate_user"
|
||||||
)
|
)
|
||||||
def admin_activate_user(
|
async def admin_activate_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Activate a user account."""
|
"""Activate a user account."""
|
||||||
try:
|
try:
|
||||||
user = user_crud.get(db, id=user_id)
|
user = await user_crud.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"User {user_id} not found",
|
detail=f"User {user_id} not found",
|
||||||
error_code=ErrorCode.USER_NOT_FOUND
|
error_code=ErrorCode.USER_NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
||||||
logger.info(f"Admin {admin.email} activated user {user.email}")
|
logger.info(f"Admin {admin.email} activated user {user.email}")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
@@ -294,14 +294,14 @@ def admin_activate_user(
|
|||||||
description="Deactivate a user account (admin only)",
|
description="Deactivate a user account (admin only)",
|
||||||
operation_id="admin_deactivate_user"
|
operation_id="admin_deactivate_user"
|
||||||
)
|
)
|
||||||
def admin_deactivate_user(
|
async def admin_deactivate_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Deactivate a user account."""
|
"""Deactivate a user account."""
|
||||||
try:
|
try:
|
||||||
user = user_crud.get(db, id=user_id)
|
user = await user_crud.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"User {user_id} not found",
|
detail=f"User {user_id} not found",
|
||||||
@@ -315,7 +315,7 @@ def admin_deactivate_user(
|
|||||||
error_code=ErrorCode.OPERATION_FORBIDDEN
|
error_code=ErrorCode.OPERATION_FORBIDDEN
|
||||||
)
|
)
|
||||||
|
|
||||||
user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
||||||
logger.info(f"Admin {admin.email} deactivated user {user.email}")
|
logger.info(f"Admin {admin.email} deactivated user {user.email}")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
@@ -337,10 +337,10 @@ def admin_deactivate_user(
|
|||||||
description="Perform bulk actions on multiple users (admin only)",
|
description="Perform bulk actions on multiple users (admin only)",
|
||||||
operation_id="admin_bulk_user_action"
|
operation_id="admin_bulk_user_action"
|
||||||
)
|
)
|
||||||
def admin_bulk_user_action(
|
async def admin_bulk_user_action(
|
||||||
bulk_action: BulkUserAction,
|
bulk_action: BulkUserAction,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Perform bulk actions on multiple users.
|
Perform bulk actions on multiple users.
|
||||||
@@ -354,7 +354,7 @@ def admin_bulk_user_action(
|
|||||||
try:
|
try:
|
||||||
for user_id in bulk_action.user_ids:
|
for user_id in bulk_action.user_ids:
|
||||||
try:
|
try:
|
||||||
user = user_crud.get(db, id=user_id)
|
user = await user_crud.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
failed_count += 1
|
failed_count += 1
|
||||||
failed_ids.append(user_id)
|
failed_ids.append(user_id)
|
||||||
@@ -367,11 +367,11 @@ def admin_bulk_user_action(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if bulk_action.action == BulkAction.ACTIVATE:
|
if bulk_action.action == BulkAction.ACTIVATE:
|
||||||
user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
||||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||||
user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
||||||
elif bulk_action.action == BulkAction.DELETE:
|
elif bulk_action.action == BulkAction.DELETE:
|
||||||
user_crud.soft_delete(db, id=user_id)
|
await user_crud.soft_delete(db, id=user_id)
|
||||||
|
|
||||||
affected_count += 1
|
affected_count += 1
|
||||||
|
|
||||||
@@ -407,16 +407,16 @@ def admin_bulk_user_action(
|
|||||||
description="Get paginated list of all organizations (admin only)",
|
description="Get paginated list of all organizations (admin only)",
|
||||||
operation_id="admin_list_organizations"
|
operation_id="admin_list_organizations"
|
||||||
)
|
)
|
||||||
def admin_list_organizations(
|
async def admin_list_organizations(
|
||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||||
search: Optional[str] = Query(None, description="Search by name, slug, description"),
|
search: Optional[str] = Query(None, description="Search by name, slug, description"),
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""List all organizations with filtering and search."""
|
"""List all organizations with filtering and search."""
|
||||||
try:
|
try:
|
||||||
orgs, total = organization_crud.get_multi_with_filters(
|
orgs, total = await organization_crud.get_multi_with_filters(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -438,7 +438,7 @@ def admin_list_organizations(
|
|||||||
"settings": org.settings,
|
"settings": org.settings,
|
||||||
"created_at": org.created_at,
|
"created_at": org.created_at,
|
||||||
"updated_at": org.updated_at,
|
"updated_at": org.updated_at,
|
||||||
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
|
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||||
}
|
}
|
||||||
orgs_with_count.append(OrganizationResponse(**org_dict))
|
orgs_with_count.append(OrganizationResponse(**org_dict))
|
||||||
|
|
||||||
@@ -464,14 +464,14 @@ def admin_list_organizations(
|
|||||||
description="Create a new organization (admin only)",
|
description="Create a new organization (admin only)",
|
||||||
operation_id="admin_create_organization"
|
operation_id="admin_create_organization"
|
||||||
)
|
)
|
||||||
def admin_create_organization(
|
async def admin_create_organization(
|
||||||
org_in: OrganizationCreate,
|
org_in: OrganizationCreate,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Create a new organization."""
|
"""Create a new organization."""
|
||||||
try:
|
try:
|
||||||
org = organization_crud.create(db, obj_in=org_in)
|
org = await organization_crud.create(db, obj_in=org_in)
|
||||||
logger.info(f"Admin {admin.email} created organization {org.name}")
|
logger.info(f"Admin {admin.email} created organization {org.name}")
|
||||||
|
|
||||||
# Add member count
|
# Add member count
|
||||||
@@ -506,13 +506,13 @@ def admin_create_organization(
|
|||||||
description="Get detailed organization information (admin only)",
|
description="Get detailed organization information (admin only)",
|
||||||
operation_id="admin_get_organization"
|
operation_id="admin_get_organization"
|
||||||
)
|
)
|
||||||
def admin_get_organization(
|
async def admin_get_organization(
|
||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Get detailed information about a specific organization."""
|
"""Get detailed information about a specific organization."""
|
||||||
org = organization_crud.get(db, id=org_id)
|
org = await organization_crud.get(db, id=org_id)
|
||||||
if not org:
|
if not org:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"Organization {org_id} not found",
|
detail=f"Organization {org_id} not found",
|
||||||
@@ -528,7 +528,7 @@ def admin_get_organization(
|
|||||||
"settings": org.settings,
|
"settings": org.settings,
|
||||||
"created_at": org.created_at,
|
"created_at": org.created_at,
|
||||||
"updated_at": org.updated_at,
|
"updated_at": org.updated_at,
|
||||||
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
|
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
@@ -540,22 +540,22 @@ def admin_get_organization(
|
|||||||
description="Update organization information (admin only)",
|
description="Update organization information (admin only)",
|
||||||
operation_id="admin_update_organization"
|
operation_id="admin_update_organization"
|
||||||
)
|
)
|
||||||
def admin_update_organization(
|
async def admin_update_organization(
|
||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
org_in: OrganizationUpdate,
|
org_in: OrganizationUpdate,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Update organization information."""
|
"""Update organization information."""
|
||||||
try:
|
try:
|
||||||
org = organization_crud.get(db, id=org_id)
|
org = await organization_crud.get(db, id=org_id)
|
||||||
if not org:
|
if not org:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"Organization {org_id} not found",
|
detail=f"Organization {org_id} not found",
|
||||||
error_code=ErrorCode.NOT_FOUND
|
error_code=ErrorCode.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_org = organization_crud.update(db, db_obj=org, obj_in=org_in)
|
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||||
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
|
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
|
||||||
|
|
||||||
org_dict = {
|
org_dict = {
|
||||||
@@ -567,7 +567,7 @@ def admin_update_organization(
|
|||||||
"settings": updated_org.settings,
|
"settings": updated_org.settings,
|
||||||
"created_at": updated_org.created_at,
|
"created_at": updated_org.created_at,
|
||||||
"updated_at": updated_org.updated_at,
|
"updated_at": updated_org.updated_at,
|
||||||
"member_count": organization_crud.get_member_count(db, organization_id=updated_org.id)
|
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
@@ -585,21 +585,21 @@ def admin_update_organization(
|
|||||||
description="Delete an organization (admin only)",
|
description="Delete an organization (admin only)",
|
||||||
operation_id="admin_delete_organization"
|
operation_id="admin_delete_organization"
|
||||||
)
|
)
|
||||||
def admin_delete_organization(
|
async def admin_delete_organization(
|
||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Delete an organization and all its relationships."""
|
"""Delete an organization and all its relationships."""
|
||||||
try:
|
try:
|
||||||
org = organization_crud.get(db, id=org_id)
|
org = await organization_crud.get(db, id=org_id)
|
||||||
if not org:
|
if not org:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"Organization {org_id} not found",
|
detail=f"Organization {org_id} not found",
|
||||||
error_code=ErrorCode.NOT_FOUND
|
error_code=ErrorCode.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
organization_crud.remove(db, id=org_id)
|
await organization_crud.remove(db, id=org_id)
|
||||||
logger.info(f"Admin {admin.email} deleted organization {org.name}")
|
logger.info(f"Admin {admin.email} deleted organization {org.name}")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
@@ -621,23 +621,23 @@ def admin_delete_organization(
|
|||||||
description="Get all members of an organization (admin only)",
|
description="Get all members of an organization (admin only)",
|
||||||
operation_id="admin_list_organization_members"
|
operation_id="admin_list_organization_members"
|
||||||
)
|
)
|
||||||
def admin_list_organization_members(
|
async def admin_list_organization_members(
|
||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
is_active: Optional[bool] = Query(True, description="Filter by active status"),
|
is_active: Optional[bool] = Query(True, description="Filter by active status"),
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""List all members of an organization."""
|
"""List all members of an organization."""
|
||||||
try:
|
try:
|
||||||
org = organization_crud.get(db, id=org_id)
|
org = await organization_crud.get(db, id=org_id)
|
||||||
if not org:
|
if not org:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"Organization {org_id} not found",
|
detail=f"Organization {org_id} not found",
|
||||||
error_code=ErrorCode.NOT_FOUND
|
error_code=ErrorCode.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
members, total = organization_crud.get_organization_members(
|
members, total = await organization_crud.get_organization_members(
|
||||||
db,
|
db,
|
||||||
organization_id=org_id,
|
organization_id=org_id,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
@@ -677,29 +677,29 @@ class AddMemberRequest(BaseModel):
|
|||||||
description="Add a user to an organization (admin only)",
|
description="Add a user to an organization (admin only)",
|
||||||
operation_id="admin_add_organization_member"
|
operation_id="admin_add_organization_member"
|
||||||
)
|
)
|
||||||
def admin_add_organization_member(
|
async def admin_add_organization_member(
|
||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
request: AddMemberRequest,
|
request: AddMemberRequest,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Add a user to an organization."""
|
"""Add a user to an organization."""
|
||||||
try:
|
try:
|
||||||
org = organization_crud.get(db, id=org_id)
|
org = await organization_crud.get(db, id=org_id)
|
||||||
if not org:
|
if not org:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"Organization {org_id} not found",
|
detail=f"Organization {org_id} not found",
|
||||||
error_code=ErrorCode.NOT_FOUND
|
error_code=ErrorCode.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
user = user_crud.get(db, id=request.user_id)
|
user = await user_crud.get(db, id=request.user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"User {request.user_id} not found",
|
detail=f"User {request.user_id} not found",
|
||||||
error_code=ErrorCode.USER_NOT_FOUND
|
error_code=ErrorCode.USER_NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
organization_crud.add_user(
|
await organization_crud.add_user(
|
||||||
db,
|
db,
|
||||||
organization_id=org_id,
|
organization_id=org_id,
|
||||||
user_id=request.user_id,
|
user_id=request.user_id,
|
||||||
@@ -733,29 +733,29 @@ def admin_add_organization_member(
|
|||||||
description="Remove a user from an organization (admin only)",
|
description="Remove a user from an organization (admin only)",
|
||||||
operation_id="admin_remove_organization_member"
|
operation_id="admin_remove_organization_member"
|
||||||
)
|
)
|
||||||
def admin_remove_organization_member(
|
async def admin_remove_organization_member(
|
||||||
org_id: UUID,
|
org_id: UUID,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
admin: User = Depends(require_superuser),
|
admin: User = Depends(require_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Remove a user from an organization."""
|
"""Remove a user from an organization."""
|
||||||
try:
|
try:
|
||||||
org = organization_crud.get(db, id=org_id)
|
org = await organization_crud.get(db, id=org_id)
|
||||||
if not org:
|
if not org:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"Organization {org_id} not found",
|
detail=f"Organization {org_id} not found",
|
||||||
error_code=ErrorCode.NOT_FOUND
|
error_code=ErrorCode.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
user = user_crud.get(db, id=user_id)
|
user = await user_crud.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"User {user_id} not found",
|
detail=f"User {user_id} not found",
|
||||||
error_code=ErrorCode.USER_NOT_FOUND
|
error_code=ErrorCode.USER_NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
success = organization_crud.remove_user(
|
success = await organization_crud.remove_user(
|
||||||
db,
|
db,
|
||||||
organization_id=org_id,
|
organization_id=org_id,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
|
|||||||
62
backend/app/api/routes/auth.py
Normal file → Executable file
62
backend/app/api/routes/auth.py
Normal file → Executable file
@@ -8,11 +8,11 @@ from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
|
|||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||||
from app.core.database import get_db
|
from app.core.database_async import get_async_db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import (
|
from app.schemas.users import (
|
||||||
UserCreate,
|
UserCreate,
|
||||||
@@ -29,8 +29,8 @@ from app.services.auth_service import AuthService, AuthenticationError
|
|||||||
from app.services.email_service import email_service
|
from app.services.email_service import email_service
|
||||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||||
from app.utils.device import extract_device_info
|
from app.utils.device import extract_device_info
|
||||||
from app.crud.user import user as user_crud
|
from app.crud.user_async import user_async as user_crud
|
||||||
from app.crud.session import session as session_crud
|
from app.crud.session_async import session_async as session_crud
|
||||||
from app.core.auth import get_password_hash
|
from app.core.auth import get_password_hash
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -49,7 +49,7 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1
|
|||||||
async def register_user(
|
async def register_user(
|
||||||
request: Request,
|
request: Request,
|
||||||
user_data: UserCreate,
|
user_data: UserCreate,
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Register a new user.
|
Register a new user.
|
||||||
@@ -58,7 +58,7 @@ async def register_user(
|
|||||||
The created user information.
|
The created user information.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
user = AuthService.create_user(db, user_data)
|
user = await AuthService.create_user(db, user_data)
|
||||||
return user
|
return user
|
||||||
except AuthenticationError as e:
|
except AuthenticationError as e:
|
||||||
logger.warning(f"Registration failed: {str(e)}")
|
logger.warning(f"Registration failed: {str(e)}")
|
||||||
@@ -79,7 +79,7 @@ async def register_user(
|
|||||||
async def login(
|
async def login(
|
||||||
request: Request,
|
request: Request,
|
||||||
login_data: LoginRequest,
|
login_data: LoginRequest,
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Login with username and password.
|
Login with username and password.
|
||||||
@@ -91,7 +91,7 @@ async def login(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Attempt to authenticate the user
|
# Attempt to authenticate the user
|
||||||
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
|
user = await AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||||
|
|
||||||
# Explicitly check for None result and raise correct exception
|
# Explicitly check for None result and raise correct exception
|
||||||
if user is None:
|
if user is None:
|
||||||
@@ -126,7 +126,7 @@ async def login(
|
|||||||
location_country=device_info.location_country,
|
location_country=device_info.location_country,
|
||||||
)
|
)
|
||||||
|
|
||||||
session_crud.create_session(db, obj_in=session_data)
|
await session_crud.create_session(db, obj_in=session_data)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User login successful: {user.email} from {device_info.device_name} "
|
f"User login successful: {user.email} from {device_info.device_name} "
|
||||||
@@ -163,7 +163,7 @@ async def login(
|
|||||||
async def login_oauth(
|
async def login_oauth(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||||
@@ -174,7 +174,7 @@ async def login_oauth(
|
|||||||
Access and refresh tokens.
|
Access and refresh tokens.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
|
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -207,7 +207,7 @@ async def login_oauth(
|
|||||||
location_country=device_info.location_country,
|
location_country=device_info.location_country,
|
||||||
)
|
)
|
||||||
|
|
||||||
session_crud.create_session(db, obj_in=session_data)
|
await session_crud.create_session(db, obj_in=session_data)
|
||||||
|
|
||||||
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
|
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
|
||||||
except Exception as session_err:
|
except Exception as session_err:
|
||||||
@@ -241,7 +241,7 @@ async def login_oauth(
|
|||||||
async def refresh_token(
|
async def refresh_token(
|
||||||
request: Request,
|
request: Request,
|
||||||
refresh_data: RefreshTokenRequest,
|
refresh_data: RefreshTokenRequest,
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Refresh access token using a refresh token.
|
Refresh access token using a refresh token.
|
||||||
@@ -256,7 +256,7 @@ async def refresh_token(
|
|||||||
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
|
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
|
||||||
|
|
||||||
# Check if session exists and is active
|
# Check if session exists and is active
|
||||||
session = session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
|
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
|
||||||
@@ -267,14 +267,14 @@ async def refresh_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Generate new tokens
|
# Generate new tokens
|
||||||
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
tokens = await AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||||
|
|
||||||
# Decode new refresh token to get new JTI
|
# Decode new refresh token to get new JTI
|
||||||
new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||||
|
|
||||||
# Update session with new refresh token JTI and expiration
|
# Update session with new refresh token JTI and expiration
|
||||||
try:
|
try:
|
||||||
session_crud.update_refresh_token(
|
await session_crud.update_refresh_token(
|
||||||
db,
|
db,
|
||||||
session=session,
|
session=session,
|
||||||
new_jti=new_refresh_payload.jti,
|
new_jti=new_refresh_payload.jti,
|
||||||
@@ -344,7 +344,7 @@ async def get_current_user_info(
|
|||||||
async def request_password_reset(
|
async def request_password_reset(
|
||||||
request: Request,
|
request: Request,
|
||||||
reset_request: PasswordResetRequest,
|
reset_request: PasswordResetRequest,
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Request a password reset.
|
Request a password reset.
|
||||||
@@ -354,7 +354,7 @@ async def request_password_reset(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Look up user by email
|
# Look up user by email
|
||||||
user = user_crud.get_by_email(db, email=reset_request.email)
|
user = await user_crud.get_by_email(db, email=reset_request.email)
|
||||||
|
|
||||||
# Only send email if user exists and is active
|
# Only send email if user exists and is active
|
||||||
if user and user.is_active:
|
if user and user.is_active:
|
||||||
@@ -399,10 +399,10 @@ async def request_password_reset(
|
|||||||
operation_id="confirm_password_reset"
|
operation_id="confirm_password_reset"
|
||||||
)
|
)
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
def confirm_password_reset(
|
async def confirm_password_reset(
|
||||||
request: Request,
|
request: Request,
|
||||||
reset_confirm: PasswordResetConfirm,
|
reset_confirm: PasswordResetConfirm,
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Confirm password reset with token.
|
Confirm password reset with token.
|
||||||
@@ -420,7 +420,7 @@ def confirm_password_reset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Look up user
|
# Look up user
|
||||||
user = user_crud.get_by_email(db, email=email)
|
user = await user_crud.get_by_email(db, email=email)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -437,7 +437,7 @@ def confirm_password_reset(
|
|||||||
# Update password
|
# Update password
|
||||||
user.password_hash = get_password_hash(reset_confirm.new_password)
|
user.password_hash = get_password_hash(reset_confirm.new_password)
|
||||||
db.add(user)
|
db.add(user)
|
||||||
db.commit()
|
await db.commit()
|
||||||
|
|
||||||
logger.info(f"Password reset successful for {user.email}")
|
logger.info(f"Password reset successful for {user.email}")
|
||||||
|
|
||||||
@@ -450,7 +450,7 @@ def confirm_password_reset(
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
|
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="An error occurred while resetting your password"
|
detail="An error occurred while resetting your password"
|
||||||
@@ -474,11 +474,11 @@ def confirm_password_reset(
|
|||||||
operation_id="logout"
|
operation_id="logout"
|
||||||
)
|
)
|
||||||
@limiter.limit("10/minute")
|
@limiter.limit("10/minute")
|
||||||
def logout(
|
async def logout(
|
||||||
request: Request,
|
request: Request,
|
||||||
logout_request: LogoutRequest,
|
logout_request: LogoutRequest,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Logout from current device by deactivating the session.
|
Logout from current device by deactivating the session.
|
||||||
@@ -505,7 +505,7 @@ def logout(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Find the session by JTI
|
# Find the session by JTI
|
||||||
session = session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||||
|
|
||||||
if session:
|
if session:
|
||||||
# Verify session belongs to current user (security check)
|
# Verify session belongs to current user (security check)
|
||||||
@@ -520,7 +520,7 @@ def logout(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Deactivate the session
|
# Deactivate the session
|
||||||
session_crud.deactivate(db, session_id=str(session.id))
|
await session_crud.deactivate(db, session_id=str(session.id))
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} logged out from {session.device_name} "
|
f"User {current_user.id} logged out from {session.device_name} "
|
||||||
@@ -563,10 +563,10 @@ def logout(
|
|||||||
operation_id="logout_all"
|
operation_id="logout_all"
|
||||||
)
|
)
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
def logout_all(
|
async def logout_all(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Logout from all devices by deactivating all user sessions.
|
Logout from all devices by deactivating all user sessions.
|
||||||
@@ -580,7 +580,7 @@ def logout_all(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Deactivate all sessions for this user
|
# Deactivate all sessions for this user
|
||||||
count = session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
||||||
|
|
||||||
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
|
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
|
||||||
|
|
||||||
@@ -591,7 +591,7 @@ def logout_all(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
|
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="An error occurred while logging out"
|
detail="An error occurred while logging out"
|
||||||
|
|||||||
38
backend/app/api/routes/organizations.py
Normal file → Executable file
38
backend/app/api/routes/organizations.py
Normal file → Executable file
@@ -9,12 +9,12 @@ from typing import Any, List, Optional
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership, get_current_org_role
|
from app.api.dependencies.permissions import require_org_admin, require_org_membership, get_current_org_role
|
||||||
from app.core.database import get_db
|
from app.core.database_async import get_async_db
|
||||||
from app.crud.organization import organization as organization_crud
|
from app.crud.organization_async import organization_async as organization_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole
|
from app.models.user_organization import OrganizationRole
|
||||||
from app.schemas.organizations import (
|
from app.schemas.organizations import (
|
||||||
@@ -42,10 +42,10 @@ router = APIRouter()
|
|||||||
description="Get all organizations the current user belongs to",
|
description="Get all organizations the current user belongs to",
|
||||||
operation_id="get_my_organizations"
|
operation_id="get_my_organizations"
|
||||||
)
|
)
|
||||||
def get_my_organizations(
|
async def get_my_organizations(
|
||||||
is_active: bool = Query(True, description="Filter by active membership"),
|
is_active: bool = Query(True, description="Filter by active membership"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get all organizations the current user belongs to.
|
Get all organizations the current user belongs to.
|
||||||
@@ -53,7 +53,7 @@ def get_my_organizations(
|
|||||||
Returns organizations with member count for each.
|
Returns organizations with member count for each.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
orgs = organization_crud.get_user_organizations(
|
orgs = await organization_crud.get_user_organizations(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
is_active=is_active
|
is_active=is_active
|
||||||
@@ -77,7 +77,7 @@ def get_my_organizations(
|
|||||||
"settings": org.settings,
|
"settings": org.settings,
|
||||||
"created_at": org.created_at,
|
"created_at": org.created_at,
|
||||||
"updated_at": org.updated_at,
|
"updated_at": org.updated_at,
|
||||||
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
|
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||||
}
|
}
|
||||||
orgs_with_data.append(OrganizationResponse(**org_dict))
|
orgs_with_data.append(OrganizationResponse(**org_dict))
|
||||||
|
|
||||||
@@ -95,10 +95,10 @@ def get_my_organizations(
|
|||||||
description="Get details of an organization the user belongs to",
|
description="Get details of an organization the user belongs to",
|
||||||
operation_id="get_organization"
|
operation_id="get_organization"
|
||||||
)
|
)
|
||||||
def get_organization(
|
async def get_organization(
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
current_user: User = Depends(require_org_membership),
|
current_user: User = Depends(require_org_membership),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get details of a specific organization.
|
Get details of a specific organization.
|
||||||
@@ -106,7 +106,7 @@ def get_organization(
|
|||||||
User must be a member of the organization.
|
User must be a member of the organization.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
org = organization_crud.get(db, id=organization_id)
|
org = await organization_crud.get(db, id=organization_id)
|
||||||
if not org:
|
if not org:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"Organization {organization_id} not found",
|
detail=f"Organization {organization_id} not found",
|
||||||
@@ -122,7 +122,7 @@ def get_organization(
|
|||||||
"settings": org.settings,
|
"settings": org.settings,
|
||||||
"created_at": org.created_at,
|
"created_at": org.created_at,
|
||||||
"updated_at": org.updated_at,
|
"updated_at": org.updated_at,
|
||||||
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
|
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
@@ -140,12 +140,12 @@ def get_organization(
|
|||||||
description="Get all members of an organization (members can view)",
|
description="Get all members of an organization (members can view)",
|
||||||
operation_id="get_organization_members"
|
operation_id="get_organization_members"
|
||||||
)
|
)
|
||||||
def get_organization_members(
|
async def get_organization_members(
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
is_active: bool = Query(True, description="Filter by active status"),
|
is_active: bool = Query(True, description="Filter by active status"),
|
||||||
current_user: User = Depends(require_org_membership),
|
current_user: User = Depends(require_org_membership),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get all members of an organization.
|
Get all members of an organization.
|
||||||
@@ -153,7 +153,7 @@ def get_organization_members(
|
|||||||
User must be a member of the organization to view members.
|
User must be a member of the organization to view members.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
members, total = organization_crud.get_organization_members(
|
members, total = await organization_crud.get_organization_members(
|
||||||
db,
|
db,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
@@ -184,11 +184,11 @@ def get_organization_members(
|
|||||||
description="Update organization details (admin/owner only)",
|
description="Update organization details (admin/owner only)",
|
||||||
operation_id="update_organization"
|
operation_id="update_organization"
|
||||||
)
|
)
|
||||||
def update_organization(
|
async def update_organization(
|
||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
org_in: OrganizationUpdate,
|
org_in: OrganizationUpdate,
|
||||||
current_user: User = Depends(require_org_admin),
|
current_user: User = Depends(require_org_admin),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update organization details.
|
Update organization details.
|
||||||
@@ -196,14 +196,14 @@ def update_organization(
|
|||||||
Requires owner or admin role in the organization.
|
Requires owner or admin role in the organization.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
org = organization_crud.get(db, id=organization_id)
|
org = await organization_crud.get(db, id=organization_id)
|
||||||
if not org:
|
if not org:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
detail=f"Organization {organization_id} not found",
|
detail=f"Organization {organization_id} not found",
|
||||||
error_code=ErrorCode.NOT_FOUND
|
error_code=ErrorCode.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_org = organization_crud.update(db, db_obj=org, obj_in=org_in)
|
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||||
logger.info(f"User {current_user.email} updated organization {updated_org.name}")
|
logger.info(f"User {current_user.email} updated organization {updated_org.name}")
|
||||||
|
|
||||||
org_dict = {
|
org_dict = {
|
||||||
@@ -215,7 +215,7 @@ def update_organization(
|
|||||||
"settings": updated_org.settings,
|
"settings": updated_org.settings,
|
||||||
"created_at": updated_org.created_at,
|
"created_at": updated_org.created_at,
|
||||||
"updated_at": updated_org.updated_at,
|
"updated_at": updated_org.updated_at,
|
||||||
"member_count": organization_crud.get_member_count(db, organization_id=updated_org.id)
|
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
|
|||||||
32
backend/app/api/routes/sessions.py
Normal file → Executable file
32
backend/app/api/routes/sessions.py
Normal file → Executable file
@@ -10,15 +10,15 @@ from uuid import UUID
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.core.database import get_db
|
from app.core.database_async import get_async_db
|
||||||
from app.core.auth import decode_token
|
from app.core.auth import decode_token
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.crud.session import session as session_crud
|
from app.crud.session_async import session_async as session_crud
|
||||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -42,10 +42,10 @@ limiter = Limiter(key_func=get_remote_address)
|
|||||||
operation_id="list_my_sessions"
|
operation_id="list_my_sessions"
|
||||||
)
|
)
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
def list_my_sessions(
|
async def list_my_sessions(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
List all active sessions for the current user.
|
List all active sessions for the current user.
|
||||||
@@ -59,7 +59,7 @@ def list_my_sessions(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get all active sessions for user
|
# Get all active sessions for user
|
||||||
sessions = session_crud.get_user_sessions(
|
sessions = await session_crud.get_user_sessions(
|
||||||
db,
|
db,
|
||||||
user_id=str(current_user.id),
|
user_id=str(current_user.id),
|
||||||
active_only=True
|
active_only=True
|
||||||
@@ -125,11 +125,11 @@ def list_my_sessions(
|
|||||||
operation_id="revoke_session"
|
operation_id="revoke_session"
|
||||||
)
|
)
|
||||||
@limiter.limit("10/minute")
|
@limiter.limit("10/minute")
|
||||||
def revoke_session(
|
async def revoke_session(
|
||||||
request: Request,
|
request: Request,
|
||||||
session_id: UUID,
|
session_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Revoke a specific session by ID.
|
Revoke a specific session by ID.
|
||||||
@@ -144,7 +144,7 @@ def revoke_session(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get the session
|
# Get the session
|
||||||
session = session_crud.get(db, id=str(session_id))
|
session = await session_crud.get(db, id=str(session_id))
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
@@ -164,7 +164,7 @@ def revoke_session(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Deactivate the session
|
# Deactivate the session
|
||||||
session_crud.deactivate(db, session_id=str(session_id))
|
await session_crud.deactivate(db, session_id=str(session_id))
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} revoked session {session_id} "
|
f"User {current_user.id} revoked session {session_id} "
|
||||||
@@ -201,10 +201,10 @@ def revoke_session(
|
|||||||
operation_id="cleanup_expired_sessions"
|
operation_id="cleanup_expired_sessions"
|
||||||
)
|
)
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
def cleanup_expired_sessions(
|
async def cleanup_expired_sessions(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Cleanup expired sessions for the current user.
|
Cleanup expired sessions for the current user.
|
||||||
@@ -220,7 +220,7 @@ def cleanup_expired_sessions(
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
# Get all sessions for user
|
# Get all sessions for user
|
||||||
all_sessions = session_crud.get_user_sessions(
|
all_sessions = await session_crud.get_user_sessions(
|
||||||
db,
|
db,
|
||||||
user_id=str(current_user.id),
|
user_id=str(current_user.id),
|
||||||
active_only=False
|
active_only=False
|
||||||
@@ -230,10 +230,10 @@ def cleanup_expired_sessions(
|
|||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
for s in all_sessions:
|
for s in all_sessions:
|
||||||
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
|
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
|
||||||
db.delete(s)
|
await db.delete(s)
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
|
|
||||||
db.commit()
|
await db.commit()
|
||||||
|
|
||||||
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
||||||
|
|
||||||
@@ -244,7 +244,7 @@ def cleanup_expired_sessions(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to cleanup sessions"
|
detail="Failed to cleanup sessions"
|
||||||
|
|||||||
46
backend/app/api/routes/users.py
Normal file → Executable file
46
backend/app/api/routes/users.py
Normal file → Executable file
@@ -6,13 +6,13 @@ from typing import Any, Optional
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, status, Request
|
from fastapi import APIRouter, Depends, Query, status, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||||
from app.core.database import get_db
|
from app.core.database_async import get_async_db
|
||||||
from app.crud.user import user as user_crud
|
from app.crud.user_async import user_async as user_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
|
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
@@ -52,13 +52,13 @@ limiter = Limiter(key_func=get_remote_address)
|
|||||||
""",
|
""",
|
||||||
operation_id="list_users"
|
operation_id="list_users"
|
||||||
)
|
)
|
||||||
def list_users(
|
async def list_users(
|
||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
sort: SortParams = Depends(),
|
sort: SortParams = Depends(),
|
||||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
List all users with pagination, filtering, and sorting.
|
List all users with pagination, filtering, and sorting.
|
||||||
@@ -74,7 +74,7 @@ def list_users(
|
|||||||
filters["is_superuser"] = is_superuser
|
filters["is_superuser"] = is_superuser
|
||||||
|
|
||||||
# Get paginated users with total count
|
# Get paginated users with total count
|
||||||
users, total = user_crud.get_multi_with_total(
|
users, total = await user_crud.get_multi_with_total(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -135,10 +135,10 @@ def get_current_user_profile(
|
|||||||
""",
|
""",
|
||||||
operation_id="update_current_user"
|
operation_id="update_current_user"
|
||||||
)
|
)
|
||||||
def update_current_user(
|
async def update_current_user(
|
||||||
user_update: UserUpdate,
|
user_update: UserUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update current user's profile.
|
Update current user's profile.
|
||||||
@@ -154,7 +154,7 @@ def update_current_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updated_user = user_crud.update(
|
updated_user = await user_crud.update(
|
||||||
db,
|
db,
|
||||||
db_obj=current_user,
|
db_obj=current_user,
|
||||||
obj_in=user_update
|
obj_in=user_update
|
||||||
@@ -185,10 +185,10 @@ def update_current_user(
|
|||||||
""",
|
""",
|
||||||
operation_id="get_user_by_id"
|
operation_id="get_user_by_id"
|
||||||
)
|
)
|
||||||
def get_user_by_id(
|
async def get_user_by_id(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get user by ID.
|
Get user by ID.
|
||||||
@@ -206,7 +206,7 @@ def get_user_by_id(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user = user_crud.get(db, id=str(user_id))
|
user = await user_crud.get(db, id=str(user_id))
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
message=f"User with id {user_id} not found",
|
message=f"User with id {user_id} not found",
|
||||||
@@ -232,11 +232,11 @@ def get_user_by_id(
|
|||||||
""",
|
""",
|
||||||
operation_id="update_user"
|
operation_id="update_user"
|
||||||
)
|
)
|
||||||
def update_user(
|
async def update_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
user_update: UserUpdate,
|
user_update: UserUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update user by ID.
|
Update user by ID.
|
||||||
@@ -257,7 +257,7 @@ def update_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user = user_crud.get(db, id=str(user_id))
|
user = await user_crud.get(db, id=str(user_id))
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
message=f"User with id {user_id} not found",
|
message=f"User with id {user_id} not found",
|
||||||
@@ -273,7 +273,7 @@ def update_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updated_user = user_crud.update(db, db_obj=user, obj_in=user_update)
|
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
|
||||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
logger.info(f"User {user_id} updated by {current_user.id}")
|
||||||
return updated_user
|
return updated_user
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -300,11 +300,11 @@ def update_user(
|
|||||||
operation_id="change_current_user_password"
|
operation_id="change_current_user_password"
|
||||||
)
|
)
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
def change_current_user_password(
|
async def change_current_user_password(
|
||||||
request: Request,
|
request: Request,
|
||||||
password_change: PasswordChange,
|
password_change: PasswordChange,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Change current user's password.
|
Change current user's password.
|
||||||
@@ -312,7 +312,7 @@ def change_current_user_password(
|
|||||||
Requires current password for verification.
|
Requires current password for verification.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
success = AuthService.change_password(
|
success = await AuthService.change_password(
|
||||||
db=db,
|
db=db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
current_password=password_change.current_password,
|
current_password=password_change.current_password,
|
||||||
@@ -353,10 +353,10 @@ def change_current_user_password(
|
|||||||
""",
|
""",
|
||||||
operation_id="delete_user"
|
operation_id="delete_user"
|
||||||
)
|
)
|
||||||
def delete_user(
|
async def delete_user(
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Delete user by ID (superuser only).
|
Delete user by ID (superuser only).
|
||||||
@@ -371,7 +371,7 @@ def delete_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user = user_crud.get(db, id=str(user_id))
|
user = await user_crud.get(db, id=str(user_id))
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
message=f"User with id {user_id} not found",
|
message=f"User with id {user_id} not found",
|
||||||
@@ -380,7 +380,7 @@ def delete_user(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Use soft delete instead of hard delete
|
# Use soft delete instead of hard delete
|
||||||
user_crud.soft_delete(db, id=str(user_id))
|
await user_crud.soft_delete(db, id=str(user_id))
|
||||||
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True,
|
success=True,
|
||||||
|
|||||||
4
backend/app/core/database_async.py
Normal file → Executable file
4
backend/app/core/database_async.py
Normal file → Executable file
@@ -159,6 +159,10 @@ async def check_async_database_health() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for consistency with main.py
|
||||||
|
check_database_health = check_async_database_health
|
||||||
|
|
||||||
|
|
||||||
async def init_async_db() -> None:
|
async def init_async_db() -> None:
|
||||||
"""
|
"""
|
||||||
Initialize async database tables.
|
Initialize async database tables.
|
||||||
|
|||||||
143
backend/app/crud/base_async.py
Normal file → Executable file
143
backend/app/crud/base_async.py
Normal file → Executable file
@@ -179,10 +179,25 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_multi_with_total(
|
async def get_multi_with_total(
|
||||||
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: Optional[str] = None,
|
||||||
|
sort_order: str = "asc",
|
||||||
|
filters: Optional[Dict[str, Any]] = None
|
||||||
) -> Tuple[List[ModelType], int]:
|
) -> Tuple[List[ModelType], int]:
|
||||||
"""
|
"""
|
||||||
Get multiple records with total count for pagination.
|
Get multiple records with total count, filtering, and sorting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
skip: Number of records to skip
|
||||||
|
limit: Maximum number of records to return
|
||||||
|
sort_by: Field name to sort by (must be a valid model attribute)
|
||||||
|
sort_order: Sort order ("asc" or "desc")
|
||||||
|
filters: Dictionary of filters (field_name: value)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (items, total_count)
|
Tuple of (items, total_count)
|
||||||
@@ -196,16 +211,35 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
raise ValueError("Maximum limit is 1000")
|
raise ValueError("Maximum limit is 1000")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get total count
|
# Build base query
|
||||||
count_result = await db.execute(
|
query = select(self.model)
|
||||||
select(func.count(self.model.id))
|
|
||||||
)
|
# Exclude soft-deleted records by default
|
||||||
|
if hasattr(self.model, 'deleted_at'):
|
||||||
|
query = query.where(self.model.deleted_at.is_(None))
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if filters:
|
||||||
|
for field, value in filters.items():
|
||||||
|
if hasattr(self.model, field) and value is not None:
|
||||||
|
query = query.where(getattr(self.model, field) == value)
|
||||||
|
|
||||||
|
# Get total count (before pagination)
|
||||||
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Get paginated items
|
# Apply sorting
|
||||||
items_result = await db.execute(
|
if sort_by and hasattr(self.model, sort_by):
|
||||||
select(self.model).offset(skip).limit(limit)
|
sort_column = getattr(self.model, sort_by)
|
||||||
)
|
if sort_order.lower() == "desc":
|
||||||
|
query = query.order_by(sort_column.desc())
|
||||||
|
else:
|
||||||
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
query = query.offset(skip).limit(limit)
|
||||||
|
items_result = await db.execute(query)
|
||||||
items = list(items_result.scalars().all())
|
items = list(items_result.scalars().all())
|
||||||
|
|
||||||
return items, total
|
return items, total
|
||||||
@@ -226,3 +260,92 @@ class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
"""Check if a record exists by ID."""
|
"""Check if a record exists by ID."""
|
||||||
obj = await self.get(db, id=id)
|
obj = await self.get(db, id=id)
|
||||||
return obj is not None
|
return obj is not None
|
||||||
|
|
||||||
|
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||||
|
"""
|
||||||
|
Soft delete a record by setting deleted_at timestamp.
|
||||||
|
|
||||||
|
Only works if the model has a 'deleted_at' column.
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
# Validate UUID format and convert to UUID object if string
|
||||||
|
try:
|
||||||
|
if isinstance(id, uuid.UUID):
|
||||||
|
uuid_obj = id
|
||||||
|
else:
|
||||||
|
uuid_obj = uuid.UUID(str(id))
|
||||||
|
except (ValueError, AttributeError, TypeError) as e:
|
||||||
|
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(self.model).where(self.model.id == uuid_obj)
|
||||||
|
)
|
||||||
|
obj = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if obj is None:
|
||||||
|
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if model supports soft deletes
|
||||||
|
if not hasattr(self.model, 'deleted_at'):
|
||||||
|
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||||
|
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||||
|
|
||||||
|
# Set deleted_at timestamp
|
||||||
|
obj.deleted_at = datetime.now(timezone.utc)
|
||||||
|
db.add(obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(obj)
|
||||||
|
return obj
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||||
|
"""
|
||||||
|
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||||
|
|
||||||
|
Only works if the model has a 'deleted_at' column.
|
||||||
|
"""
|
||||||
|
# Validate UUID format
|
||||||
|
try:
|
||||||
|
if isinstance(id, uuid.UUID):
|
||||||
|
uuid_obj = id
|
||||||
|
else:
|
||||||
|
uuid_obj = uuid.UUID(str(id))
|
||||||
|
except (ValueError, AttributeError, TypeError) as e:
|
||||||
|
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Find the soft-deleted record
|
||||||
|
if hasattr(self.model, 'deleted_at'):
|
||||||
|
result = await db.execute(
|
||||||
|
select(self.model).where(
|
||||||
|
self.model.id == uuid_obj,
|
||||||
|
self.model.deleted_at.isnot(None)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
obj = result.scalar_one_or_none()
|
||||||
|
else:
|
||||||
|
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||||
|
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||||
|
|
||||||
|
if obj is None:
|
||||||
|
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Clear deleted_at timestamp
|
||||||
|
obj.deleted_at = None
|
||||||
|
db.add(obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(obj)
|
||||||
|
return obj
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|||||||
0
backend/app/init_db.py
Normal file → Executable file
0
backend/app/init_db.py
Normal file → Executable file
4
backend/app/main.py
Normal file → Executable file
4
backend/app/main.py
Normal file → Executable file
@@ -14,7 +14,7 @@ from sqlalchemy import text
|
|||||||
|
|
||||||
from app.api.main import api_router
|
from app.api.main import api_router
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import get_db, check_database_health
|
from app.core.database_async import check_database_health
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
APIException,
|
APIException,
|
||||||
api_exception_handler,
|
api_exception_handler,
|
||||||
@@ -218,7 +218,7 @@ async def health_check() -> JSONResponse:
|
|||||||
|
|
||||||
# Database health check using dedicated health check function
|
# Database health check using dedicated health check function
|
||||||
try:
|
try:
|
||||||
db_healthy = check_database_health()
|
db_healthy = await check_database_health()
|
||||||
if db_healthy:
|
if db_healthy:
|
||||||
health_status["checks"]["database"] = {
|
health_status["checks"]["database"] = {
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
|
|||||||
29
backend/app/services/auth_service.py
Normal file → Executable file
29
backend/app/services/auth_service.py
Normal file → Executable file
@@ -3,7 +3,8 @@ import logging
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.core.auth import (
|
from app.core.auth import (
|
||||||
verify_password,
|
verify_password,
|
||||||
@@ -28,7 +29,7 @@ class AuthService:
|
|||||||
"""Service for handling authentication operations"""
|
"""Service for handling authentication operations"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
|
||||||
"""
|
"""
|
||||||
Authenticate a user with email and password.
|
Authenticate a user with email and password.
|
||||||
|
|
||||||
@@ -40,7 +41,8 @@ class AuthService:
|
|||||||
Returns:
|
Returns:
|
||||||
User if authenticated, None otherwise
|
User if authenticated, None otherwise
|
||||||
"""
|
"""
|
||||||
user = db.query(User).filter(User.email == email).first()
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
@@ -54,7 +56,7 @@ class AuthService:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_user(db: Session, user_data: UserCreate) -> User:
|
async def create_user(db: AsyncSession, user_data: UserCreate) -> User:
|
||||||
"""
|
"""
|
||||||
Create a new user.
|
Create a new user.
|
||||||
|
|
||||||
@@ -66,7 +68,8 @@ class AuthService:
|
|||||||
Created user
|
Created user
|
||||||
"""
|
"""
|
||||||
# Check if user already exists
|
# Check if user already exists
|
||||||
existing_user = db.query(User).filter(User.email == user_data.email).first()
|
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||||
|
existing_user = result.scalar_one_or_none()
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise AuthenticationError("User with this email already exists")
|
raise AuthenticationError("User with this email already exists")
|
||||||
|
|
||||||
@@ -85,8 +88,8 @@ class AuthService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
db.add(user)
|
db.add(user)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(user)
|
await db.refresh(user)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@@ -124,7 +127,7 @@ class AuthService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def refresh_tokens(db: Session, refresh_token: str) -> Token:
|
async def refresh_tokens(db: AsyncSession, refresh_token: str) -> Token:
|
||||||
"""
|
"""
|
||||||
Generate new tokens using a refresh token.
|
Generate new tokens using a refresh token.
|
||||||
|
|
||||||
@@ -150,7 +153,8 @@ class AuthService:
|
|||||||
user_id = token_data.user_id
|
user_id = token_data.user_id
|
||||||
|
|
||||||
# Get user from database
|
# Get user from database
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
raise TokenInvalidError("Invalid user or inactive account")
|
raise TokenInvalidError("Invalid user or inactive account")
|
||||||
|
|
||||||
@@ -162,7 +166,7 @@ class AuthService:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
|
async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Change a user's password.
|
Change a user's password.
|
||||||
|
|
||||||
@@ -178,7 +182,8 @@ class AuthService:
|
|||||||
Raises:
|
Raises:
|
||||||
AuthenticationError: If current password is incorrect
|
AuthenticationError: If current password is incorrect
|
||||||
"""
|
"""
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
if not user:
|
if not user:
|
||||||
raise AuthenticationError("User not found")
|
raise AuthenticationError("User not found")
|
||||||
|
|
||||||
@@ -188,6 +193,6 @@ class AuthService:
|
|||||||
|
|
||||||
# Update password
|
# Update password
|
||||||
user.password_hash = get_password_hash(new_password)
|
user.password_hash = get_password_hash(new_password)
|
||||||
db.commit()
|
await db.commit()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|||||||
78
backend/app/services/session_cleanup.py
Normal file → Executable file
78
backend/app/services/session_cleanup.py
Normal file → Executable file
@@ -6,13 +6,13 @@ This service runs periodically to remove old session records from the database.
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from app.core.database import SessionLocal
|
from app.core.database_async import AsyncSessionLocal
|
||||||
from app.crud.session import session as session_crud
|
from app.crud.session_async import session_async as session_crud
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||||
"""
|
"""
|
||||||
Clean up expired and inactive sessions.
|
Clean up expired and inactive sessions.
|
||||||
|
|
||||||
@@ -29,52 +29,58 @@ def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
|||||||
"""
|
"""
|
||||||
logger.info("Starting session cleanup job...")
|
logger.info("Starting session cleanup job...")
|
||||||
|
|
||||||
db = SessionLocal()
|
async with AsyncSessionLocal() as db:
|
||||||
try:
|
try:
|
||||||
# Use CRUD method to cleanup
|
# Use CRUD method to cleanup
|
||||||
count = session_crud.cleanup_expired(db, keep_days=keep_days)
|
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
|
||||||
|
|
||||||
logger.info(f"Session cleanup complete: {count} sessions deleted")
|
logger.info(f"Session cleanup complete: {count} sessions deleted")
|
||||||
|
|
||||||
return count
|
return count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
|
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
|
||||||
return 0
|
return 0
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def get_session_statistics() -> dict:
|
async def get_session_statistics() -> dict:
|
||||||
"""
|
"""
|
||||||
Get statistics about current sessions.
|
Get statistics about current sessions.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with session stats
|
Dictionary with session stats
|
||||||
"""
|
"""
|
||||||
db = SessionLocal()
|
async with AsyncSessionLocal() as db:
|
||||||
try:
|
try:
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
|
||||||
total_sessions = db.query(UserSession).count()
|
total_result = await db.execute(select(func.count(UserSession.id)))
|
||||||
active_sessions = db.query(UserSession).filter(UserSession.is_active == True).count()
|
total_sessions = total_result.scalar_one()
|
||||||
expired_sessions = db.query(UserSession).filter(
|
|
||||||
UserSession.expires_at < datetime.now(timezone.utc)
|
|
||||||
).count()
|
|
||||||
|
|
||||||
stats = {
|
active_result = await db.execute(
|
||||||
"total": total_sessions,
|
select(func.count(UserSession.id)).where(UserSession.is_active == True)
|
||||||
"active": active_sessions,
|
)
|
||||||
"inactive": total_sessions - active_sessions,
|
active_sessions = active_result.scalar_one()
|
||||||
"expired": expired_sessions,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Session statistics: {stats}")
|
expired_result = await db.execute(
|
||||||
|
select(func.count(UserSession.id)).where(
|
||||||
|
UserSession.expires_at < datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
expired_sessions = expired_result.scalar_one()
|
||||||
|
|
||||||
return stats
|
stats = {
|
||||||
|
"total": total_sessions,
|
||||||
|
"active": active_sessions,
|
||||||
|
"inactive": total_sessions - active_sessions,
|
||||||
|
"expired": expired_sessions,
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
logger.info(f"Session statistics: {stats}")
|
||||||
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
|
|
||||||
return {}
|
return stats
|
||||||
finally:
|
|
||||||
db.close()
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
|
||||||
|
return {}
|
||||||
|
|||||||
Reference in New Issue
Block a user