Refactor backend to adopt async patterns across services, API routes, and CRUD operations
- Migrated database sessions and operations to `AsyncSession` for full async support. - Updated all service methods and dependencies (`get_db` to `get_async_db`) to support async logic. - Refactored admin, user, organization, session-related CRUD methods, and routes with await syntax. - Improved consistency and performance with async SQLAlchemy patterns. - Enhanced logging and error handling for async context.
This commit is contained in:
29
backend/app/services/auth_service.py
Normal file → Executable file
29
backend/app/services/auth_service.py
Normal file → Executable file
@@ -3,7 +3,8 @@ import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
@@ -28,7 +29,7 @@ class AuthService:
|
||||
"""Service for handling authentication operations"""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
|
||||
@@ -40,7 +41,8 @@ class AuthService:
|
||||
Returns:
|
||||
User if authenticated, None otherwise
|
||||
"""
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
@@ -54,7 +56,7 @@ class AuthService:
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_user(db: Session, user_data: UserCreate) -> User:
|
||||
async def create_user(db: AsyncSession, user_data: UserCreate) -> User:
|
||||
"""
|
||||
Create a new user.
|
||||
|
||||
@@ -66,7 +68,8 @@ class AuthService:
|
||||
Created user
|
||||
"""
|
||||
# Check if user already exists
|
||||
existing_user = db.query(User).filter(User.email == user_data.email).first()
|
||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||
existing_user = result.scalar_one_or_none()
|
||||
if existing_user:
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
|
||||
@@ -85,8 +88,8 @@ class AuthService:
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
@@ -124,7 +127,7 @@ class AuthService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def refresh_tokens(db: Session, refresh_token: str) -> Token:
|
||||
async def refresh_tokens(db: AsyncSession, refresh_token: str) -> Token:
|
||||
"""
|
||||
Generate new tokens using a refresh token.
|
||||
|
||||
@@ -150,7 +153,8 @@ class AuthService:
|
||||
user_id = token_data.user_id
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.is_active:
|
||||
raise TokenInvalidError("Invalid user or inactive account")
|
||||
|
||||
@@ -162,7 +166,7 @@ class AuthService:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||
async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||
"""
|
||||
Change a user's password.
|
||||
|
||||
@@ -178,7 +182,8 @@ class AuthService:
|
||||
Raises:
|
||||
AuthenticationError: If current password is incorrect
|
||||
"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
@@ -188,6 +193,6 @@ class AuthService:
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
db.commit()
|
||||
await db.commit()
|
||||
|
||||
return True
|
||||
|
||||
78
backend/app/services/session_cleanup.py
Normal file → Executable file
78
backend/app/services/session_cleanup.py
Normal file → Executable file
@@ -6,13 +6,13 @@ This service runs periodically to remove old session records from the database.
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.database_async import AsyncSessionLocal
|
||||
from app.crud.session_async import session_async as session_crud
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired and inactive sessions.
|
||||
|
||||
@@ -29,52 +29,58 @@ def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
"""
|
||||
logger.info("Starting session cleanup job...")
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Use CRUD method to cleanup
|
||||
count = session_crud.cleanup_expired(db, keep_days=keep_days)
|
||||
async with AsyncSessionLocal() as db:
|
||||
try:
|
||||
# Use CRUD method to cleanup
|
||||
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
|
||||
|
||||
logger.info(f"Session cleanup complete: {count} sessions deleted")
|
||||
logger.info(f"Session cleanup complete: {count} sessions deleted")
|
||||
|
||||
return count
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
|
||||
|
||||
def get_session_statistics() -> dict:
|
||||
async def get_session_statistics() -> dict:
|
||||
"""
|
||||
Get statistics about current sessions.
|
||||
|
||||
Returns:
|
||||
Dictionary with session stats
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.user_session import UserSession
|
||||
async with AsyncSessionLocal() as db:
|
||||
try:
|
||||
from app.models.user_session import UserSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
total_sessions = db.query(UserSession).count()
|
||||
active_sessions = db.query(UserSession).filter(UserSession.is_active == True).count()
|
||||
expired_sessions = db.query(UserSession).filter(
|
||||
UserSession.expires_at < datetime.now(timezone.utc)
|
||||
).count()
|
||||
total_result = await db.execute(select(func.count(UserSession.id)))
|
||||
total_sessions = total_result.scalar_one()
|
||||
|
||||
stats = {
|
||||
"total": total_sessions,
|
||||
"active": active_sessions,
|
||||
"inactive": total_sessions - active_sessions,
|
||||
"expired": expired_sessions,
|
||||
}
|
||||
active_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(UserSession.is_active == True)
|
||||
)
|
||||
active_sessions = active_result.scalar_one()
|
||||
|
||||
logger.info(f"Session statistics: {stats}")
|
||||
expired_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
UserSession.expires_at < datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
expired_sessions = expired_result.scalar_one()
|
||||
|
||||
return stats
|
||||
stats = {
|
||||
"total": total_sessions,
|
||||
"active": active_sessions,
|
||||
"inactive": total_sessions - active_sessions,
|
||||
"expired": expired_sessions,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
finally:
|
||||
db.close()
|
||||
logger.info(f"Session statistics: {stats}")
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
|
||||
Reference in New Issue
Block a user