forked from cardosofelipe/fast-next-template
- Replaced `not UserSession.is_active` with `UserSession.is_active == False` in cleanup queries for explicit comparison. - Added `mypy` overrides for `app.alembic` and external libraries (`starlette`). - Refactored `Makefile` to use virtual environment binaries for commands like `ruff`, `mypy`, and `pytest`.
467 lines
14 KiB
Python
Executable File
467 lines
14 KiB
Python
Executable File
"""
|
|
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
|
"""
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import UTC, datetime, timedelta
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import and_, delete, func, select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
from app.crud.base import CRUDBase
|
|
from app.models.user_session import UserSession
|
|
from app.schemas.sessions import SessionCreate, SessionUpdate
|
|
|
|
# Module-level logger named after this module (via __name__).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
    """Async CRUD operations for user sessions.

    Every method takes the ``AsyncSession`` to operate on as ``db``. Methods
    that write to the database commit on success and roll back before
    propagating any exception. Read-only methods log the error and re-raise.
    """

    async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
        """
        Get session by refresh token JTI.

        Args:
            db: Database session
            jti: Refresh token JWT ID

        Returns:
            UserSession if found, None otherwise
        """
        try:
            result = await db.execute(
                select(UserSession).where(UserSession.refresh_token_jti == jti)
            )
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error(f"Error getting session by JTI {jti}: {e!s}")
            raise

    async def get_active_by_jti(
        self, db: AsyncSession, *, jti: str
    ) -> UserSession | None:
        """
        Get active session by refresh token JTI.

        Args:
            db: Database session
            jti: Refresh token JWT ID

        Returns:
            Active UserSession if found, None otherwise
        """
        try:
            result = await db.execute(
                select(UserSession).where(
                    and_(
                        UserSession.refresh_token_jti == jti,
                        UserSession.is_active,
                    )
                )
            )
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error(f"Error getting active session by JTI {jti}: {e!s}")
            raise

    async def get_user_sessions(
        self,
        db: AsyncSession,
        *,
        user_id: str,
        active_only: bool = True,
        with_user: bool = False,
    ) -> list[UserSession]:
        """
        Get all sessions for a user with optional eager loading.

        Results are ordered by ``last_used_at``, most recent first.

        Args:
            db: Database session
            user_id: User ID (a string is converted to UUID; other values
                are passed through unchanged)
            active_only: If True, return only active sessions
            with_user: If True, eager load user relationship to prevent N+1

        Returns:
            List of UserSession objects
        """
        try:
            # Convert user_id string to UUID if needed
            user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id

            query = select(UserSession).where(UserSession.user_id == user_uuid)

            # Add eager loading if requested to prevent N+1 queries
            if with_user:
                query = query.options(joinedload(UserSession.user))

            if active_only:
                query = query.where(UserSession.is_active)

            query = query.order_by(UserSession.last_used_at.desc())
            result = await db.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error(f"Error getting sessions for user {user_id}: {e!s}")
            raise

    async def create_session(
        self, db: AsyncSession, *, obj_in: SessionCreate
    ) -> UserSession:
        """
        Create a new user session.

        The new session is always stored with ``is_active=True``.

        Args:
            db: Database session
            obj_in: SessionCreate schema with session data

        Returns:
            Created UserSession

        Raises:
            ValueError: If session creation fails (the original exception is
                attached as ``__cause__``)
        """
        try:
            db_obj = UserSession(
                user_id=obj_in.user_id,
                refresh_token_jti=obj_in.refresh_token_jti,
                device_name=obj_in.device_name,
                device_id=obj_in.device_id,
                ip_address=obj_in.ip_address,
                user_agent=obj_in.user_agent,
                last_used_at=obj_in.last_used_at,
                expires_at=obj_in.expires_at,
                is_active=True,
                location_city=obj_in.location_city,
                location_country=obj_in.location_country,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)

            logger.info(
                f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
                f"(IP: {obj_in.ip_address})"
            )

            return db_obj
        except Exception as e:
            await db.rollback()
            logger.error(f"Error creating session: {e!s}", exc_info=True)
            # Chain explicitly so the root cause stays visible in tracebacks.
            raise ValueError(f"Failed to create session: {e!s}") from e

    async def deactivate(
        self, db: AsyncSession, *, session_id: str
    ) -> UserSession | None:
        """
        Deactivate a session (logout from device).

        Args:
            db: Database session
            session_id: Session UUID

        Returns:
            Deactivated UserSession if found, None otherwise
        """
        try:
            session = await self.get(db, id=session_id)
            if not session:
                logger.warning(f"Session {session_id} not found for deactivation")
                return None

            session.is_active = False
            db.add(session)
            await db.commit()
            await db.refresh(session)

            logger.info(
                f"Session {session_id} deactivated for user {session.user_id} "
                f"({session.device_name})"
            )

            return session
        except Exception as e:
            await db.rollback()
            logger.error(f"Error deactivating session {session_id}: {e!s}")
            raise

    async def deactivate_all_user_sessions(
        self, db: AsyncSession, *, user_id: str
    ) -> int:
        """
        Deactivate all active sessions for a user (logout from all devices).

        Issues a single bulk UPDATE rather than loading each session.

        Args:
            db: Database session
            user_id: User ID (a string is converted to UUID; other values
                are passed through unchanged)

        Returns:
            Number of sessions deactivated
        """
        try:
            # Convert user_id string to UUID if needed
            user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id

            stmt = (
                update(UserSession)
                .where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
                .values(is_active=False)
            )

            result = await db.execute(stmt)
            await db.commit()

            count = result.rowcount

            logger.info(f"Deactivated {count} sessions for user {user_id}")

            return count
        except Exception as e:
            await db.rollback()
            logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
            raise

    async def update_last_used(
        self, db: AsyncSession, *, session: UserSession
    ) -> UserSession:
        """
        Update the last_used_at timestamp for a session.

        The timestamp is set to the current time in UTC.

        Args:
            db: Database session
            session: UserSession object

        Returns:
            Updated UserSession
        """
        try:
            session.last_used_at = datetime.now(UTC)
            db.add(session)
            await db.commit()
            await db.refresh(session)
            return session
        except Exception as e:
            await db.rollback()
            logger.error(f"Error updating last_used for session {session.id}: {e!s}")
            raise

    async def update_refresh_token(
        self,
        db: AsyncSession,
        *,
        session: UserSession,
        new_jti: str,
        new_expires_at: datetime,
    ) -> UserSession:
        """
        Update session with new refresh token JTI and expiration.

        Called during token refresh. Also sets ``last_used_at`` to the
        current time in UTC.

        Args:
            db: Database session
            session: UserSession object
            new_jti: New refresh token JTI
            new_expires_at: New expiration datetime

        Returns:
            Updated UserSession
        """
        try:
            session.refresh_token_jti = new_jti
            session.expires_at = new_expires_at
            session.last_used_at = datetime.now(UTC)
            db.add(session)
            await db.commit()
            await db.refresh(session)
            return session
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error updating refresh token for session {session.id}: {e!s}"
            )
            raise

    async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
        """
        Clean up expired sessions using optimized bulk DELETE.

        Deletes sessions for which ALL of the following hold:
        - The session is inactive
        - ``expires_at`` is in the past
        - ``created_at`` is more than ``keep_days`` days in the past

        Uses single DELETE query instead of N individual deletes for efficiency.

        Args:
            db: Database session
            keep_days: Keep inactive sessions for this many days (for audit)

        Returns:
            Number of sessions deleted
        """
        try:
            # Take one timestamp so both comparisons use the same instant.
            now = datetime.now(UTC)
            cutoff_date = now - timedelta(days=keep_days)

            # Use bulk DELETE with WHERE clause - single query
            stmt = delete(UserSession).where(
                and_(
                    UserSession.is_active == False,  # noqa: E712
                    UserSession.expires_at < now,
                    UserSession.created_at < cutoff_date,
                )
            )

            result = await db.execute(stmt)
            await db.commit()

            count = result.rowcount

            if count > 0:
                logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")

            return count
        except Exception as e:
            await db.rollback()
            logger.error(f"Error cleaning up expired sessions: {e!s}")
            raise

    async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
        """
        Clean up expired and inactive sessions for a specific user.

        Uses single bulk DELETE query for efficiency instead of N individual deletes.

        Args:
            db: Database session
            user_id: User ID to cleanup sessions for (a UUID string; a UUID
                instance is also accepted, as in the other user-scoped methods)

        Returns:
            Number of sessions deleted

        Raises:
            ValueError: If user_id is not a valid UUID
        """
        # Validate before touching the database so a malformed ID does not
        # trigger a rollback or a second, misleading error log entry.
        if isinstance(user_id, UUID):
            uuid_obj = user_id
        else:
            try:
                uuid_obj = uuid.UUID(user_id)
            except (ValueError, AttributeError) as e:
                logger.error(f"Invalid UUID format: {user_id}")
                raise ValueError(f"Invalid user ID format: {user_id}") from e

        try:
            now = datetime.now(UTC)

            # Use bulk DELETE with WHERE clause - single query
            stmt = delete(UserSession).where(
                and_(
                    UserSession.user_id == uuid_obj,
                    UserSession.is_active == False,  # noqa: E712
                    UserSession.expires_at < now,
                )
            )

            result = await db.execute(stmt)
            await db.commit()

            count = result.rowcount

            if count > 0:
                logger.info(
                    f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
                )

            return count
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error cleaning up expired sessions for user {user_id}: {e!s}"
            )
            raise

    async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
        """
        Get count of active sessions for a user.

        Args:
            db: Database session
            user_id: User ID (a string is converted to UUID; other values
                are passed through unchanged)

        Returns:
            Number of active sessions
        """
        try:
            # Convert user_id string to UUID if needed
            user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id

            result = await db.execute(
                select(func.count(UserSession.id)).where(
                    and_(UserSession.user_id == user_uuid, UserSession.is_active)
                )
            )
            return result.scalar_one()
        except Exception as e:
            logger.error(f"Error counting sessions for user {user_id}: {e!s}")
            raise

    async def get_all_sessions(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        active_only: bool = True,
        with_user: bool = True,
    ) -> tuple[list[UserSession], int]:
        """
        Get all sessions across all users with pagination (admin only).

        The total count is computed with the same ``active_only`` filter as
        the page query, but without pagination. The page is ordered by
        ``last_used_at``, most recent first.

        Args:
            db: Database session
            skip: Number of records to skip
            limit: Maximum number of records to return
            active_only: If True, return only active sessions
            with_user: If True, eager load user relationship to prevent N+1

        Returns:
            Tuple of (list of UserSession objects, total count)
        """
        try:
            # Build query
            query = select(UserSession)

            # Add eager loading if requested to prevent N+1 queries
            if with_user:
                query = query.options(joinedload(UserSession.user))

            if active_only:
                query = query.where(UserSession.is_active)

            # Get total count
            count_query = select(func.count(UserSession.id))
            if active_only:
                count_query = count_query.where(UserSession.is_active)

            count_result = await db.execute(count_query)
            total = count_result.scalar_one()

            # Apply pagination and ordering
            query = (
                query.order_by(UserSession.last_used_at.desc())
                .offset(skip)
                .limit(limit)
            )

            result = await db.execute(query)
            sessions = list(result.scalars().all())

            return sessions, total

        except Exception as e:
            logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
            raise
|
|
|
|
|
|
# Create singleton instance: one shared CRUDSession bound to the UserSession model
session = CRUDSession(UserSession)
|