Add async CRUD classes for organizations, sessions, and users

- Implemented `CRUDOrganizationAsync`, `CRUDSessionAsync`, and `CRUDUserAsync` with full async support for database operations.
- Added filtering, sorting, pagination, and advanced methods for organization management.
- Developed session-specific logic, including cleanup, per-device management, and security enhancements.
- Enhanced user CRUD with password hashing and comprehensive update handling.
This commit is contained in:
Felipe Cardoso
2025-10-31 21:59:40 +01:00
parent 26ff08d9f9
commit 1f15ee6db3
3 changed files with 917 additions and 0 deletions

363
backend/app/crud/session_async.py Executable file
View File

@@ -0,0 +1,363 @@
"""
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import and_, select, update, func
import logging
from app.crud.base_async import CRUDBaseAsync
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__)
class CRUDSessionAsync(CRUDBaseAsync[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
"""
try:
result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
raise
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
"""
Get active session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
try:
result = await db.execute(
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
raise
async def get_user_sessions(
self,
db: AsyncSession,
*,
user_id: str,
active_only: bool = True
) -> List[UserSession]:
"""
Get all sessions for a user.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
Returns:
List of UserSession objects
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = select(UserSession).where(UserSession.user_id == user_uuid)
if active_only:
query = query.where(UserSession.is_active == True)
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
raise
async def create_session(
self,
db: AsyncSession,
*,
obj_in: SessionCreate
) -> UserSession:
"""
Create a new user session.
Args:
db: Database session
obj_in: SessionCreate schema with session data
Returns:
Created UserSession
Raises:
ValueError: If session creation fails
"""
try:
db_obj = UserSession(
user_id=obj_in.user_id,
refresh_token_jti=obj_in.refresh_token_jti,
device_name=obj_in.device_name,
device_id=obj_in.device_id,
ip_address=obj_in.ip_address,
user_agent=obj_in.user_agent,
last_used_at=obj_in.last_used_at,
expires_at=obj_in.expires_at,
is_active=True,
location_city=obj_in.location_city,
location_country=obj_in.location_country,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
f"(IP: {obj_in.ip_address})"
)
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
"""
Deactivate a session (logout from device).
Args:
db: Database session
session_id: Session UUID
Returns:
Deactivated UserSession if found, None otherwise
"""
try:
session = await self.get(db, id=session_id)
if not session:
logger.warning(f"Session {session_id} not found for deactivation")
return None
session.is_active = False
db.add(session)
await db.commit()
await db.refresh(session)
logger.info(
f"Session {session_id} deactivated for user {session.user_id} "
f"({session.device_name})"
)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
raise
async def deactivate_all_user_sessions(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
"""
Deactivate all active sessions for a user (logout from all devices).
Args:
db: Database session
user_id: User ID
Returns:
Number of sessions deactivated
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
stmt = (
update(UserSession)
.where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
.values(is_active=False)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
raise
async def update_last_used(
self,
db: AsyncSession,
*,
session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
Args:
db: Database session
session: UserSession object
Returns:
Updated UserSession
"""
try:
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
raise
async def update_refresh_token(
self,
db: AsyncSession,
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
Called during token refresh.
Args:
db: Database session
session: UserSession object
new_jti: New refresh token JTI
new_expires_at: New expiration datetime
Returns:
Updated UserSession
"""
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
Returns:
Number of sessions deleted
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
# Get sessions to delete
stmt = select(UserSession).where(
and_(
UserSession.is_active == False,
UserSession.expires_at < datetime.now(timezone.utc),
UserSession.created_at < cutoff_date
)
)
result = await db.execute(stmt)
sessions_to_delete = list(result.scalars().all())
# Delete them
for session in sessions_to_delete:
await db.delete(session)
await db.commit()
count = len(sessions_to_delete)
if count > 0:
logger.info(f"Cleaned up {count} expired sessions")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
"""
Get count of active sessions for a user.
Args:
db: Database session
user_id: User ID
Returns:
Number of active sessions
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute(
select(func.count(UserSession.id)).where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
raise
# Create singleton instance
session_async = CRUDSessionAsync(UserSession)