- introduce custom repository exception hierarchy (DuplicateEntryError, IntegrityConstraintError, InvalidInputError) replacing raw ValueError - eliminate all direct repository imports and raw SQL from route layer - add UserService, SessionService, OrganizationService to service layer - add get_stats/get_org_distribution service methods replacing admin inline SQL - fix timing side-channel in authenticate_user via dummy bcrypt check - replace SHA-256 client secret fallback with explicit InvalidClientError - replace assert with InvalidGrantError in authorization code exchange - replace N+1 token revocation loops with bulk UPDATE statements - rename oauth account token fields (drop misleading 'encrypted' suffix) - add Alembic migration 0003 for token field column rename - add 45 new service/repository tests; 975 passing, 94% coverage
328 lines
11 KiB
Python
328 lines
11 KiB
Python
# app/repositories/session.py
|
|
"""Repository for UserSession model async CRUD operations using SQLAlchemy 2.0 patterns."""
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import UTC, datetime, timedelta
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import and_, delete, func, select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
from app.core.repository_exceptions import InvalidInputError, IntegrityConstraintError
|
|
from app.models.user_session import UserSession
|
|
from app.repositories.base import BaseRepository
|
|
from app.schemas.sessions import SessionCreate, SessionUpdate
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
    """Async repository for ``UserSession`` rows.

    Every method takes the ``AsyncSession`` to operate on as ``db``. Methods
    that write commit on success and roll back before propagating a failure;
    read-only methods only log and re-raise.
    """

    async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
        """Return the session whose ``refresh_token_jti`` equals ``jti``.

        Args:
            db: Database session used to run the query.
            jti: Refresh token JTI to look up.

        Returns:
            The matching ``UserSession`` (active or not), or ``None``.

        Raises:
            Exception: Any error raised while executing the query is logged
                and re-raised unchanged.
        """
        try:
            result = await db.execute(
                select(UserSession).where(UserSession.refresh_token_jti == jti)
            )
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error(f"Error getting session by JTI {jti}: {e!s}")
            raise

    async def get_active_by_jti(
        self, db: AsyncSession, *, jti: str
    ) -> UserSession | None:
        """Return the session matching ``jti`` only if it is flagged active.

        Args:
            db: Database session used to run the query.
            jti: Refresh token JTI to look up.

        Returns:
            The matching ``UserSession`` with ``is_active`` true, or ``None``.

        Raises:
            Exception: Any error raised while executing the query is logged
                and re-raised unchanged.
        """
        try:
            result = await db.execute(
                select(UserSession).where(
                    and_(
                        UserSession.refresh_token_jti == jti,
                        UserSession.is_active,
                    )
                )
            )
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error(f"Error getting active session by JTI {jti}: {e!s}")
            raise

    async def get_user_sessions(
        self,
        db: AsyncSession,
        *,
        user_id: str | UUID,
        active_only: bool = True,
        with_user: bool = False,
    ) -> list[UserSession]:
        """Return a user's sessions ordered by ``last_used_at`` descending.

        Args:
            db: Database session used to run the query.
            user_id: Owner of the sessions; a string is parsed as a UUID,
                any other value is used as-is.
            active_only: When true, restrict the result to active sessions.
            with_user: When true, eager-load ``UserSession.user`` via a join.

        Returns:
            A list of matching sessions (possibly empty).

        Raises:
            ValueError: If ``user_id`` is a string that is not a valid UUID.
            Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id

            query = select(UserSession).where(UserSession.user_id == user_uuid)

            if with_user:
                query = query.options(joinedload(UserSession.user))

            if active_only:
                query = query.where(UserSession.is_active)

            query = query.order_by(UserSession.last_used_at.desc())
            result = await db.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error(f"Error getting sessions for user {user_id}: {e!s}")
            raise

    async def create_session(
        self, db: AsyncSession, *, obj_in: SessionCreate
    ) -> UserSession:
        """Insert a new active session built from ``obj_in`` and commit it.

        Args:
            db: Database session used for the insert.
            obj_in: Schema carrying the session fields to persist.

        Returns:
            The refreshed, persisted ``UserSession``.

        Raises:
            IntegrityConstraintError: If anything fails while adding,
                committing or refreshing the row. The transaction is rolled
                back first and the original exception is attached as
                ``__cause__``.
        """
        try:
            db_obj = UserSession(
                user_id=obj_in.user_id,
                refresh_token_jti=obj_in.refresh_token_jti,
                device_name=obj_in.device_name,
                device_id=obj_in.device_id,
                ip_address=obj_in.ip_address,
                user_agent=obj_in.user_agent,
                last_used_at=obj_in.last_used_at,
                expires_at=obj_in.expires_at,
                is_active=True,
                location_city=obj_in.location_city,
                location_country=obj_in.location_country,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)

            logger.info(
                f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
                f"(IP: {obj_in.ip_address})"
            )

            return db_obj
        except Exception as e:
            await db.rollback()
            logger.error(f"Error creating session: {e!s}", exc_info=True)
            # Chain explicitly so the underlying database error stays
            # reachable as __cause__ for callers and tracebacks.
            raise IntegrityConstraintError(f"Failed to create session: {e!s}") from e

    async def deactivate(
        self, db: AsyncSession, *, session_id: str
    ) -> UserSession | None:
        """Mark a single session inactive (logout from one device).

        Args:
            db: Database session used for the update.
            session_id: Identifier passed to ``self.get`` to load the session.

        Returns:
            The refreshed session after deactivation, or ``None`` when
            ``self.get`` finds no session (a warning is logged in that case).

        Raises:
            Exception: Any error is logged and re-raised after a rollback.
        """
        try:
            session = await self.get(db, id=session_id)
            if not session:
                logger.warning(f"Session {session_id} not found for deactivation")
                return None

            session.is_active = False
            db.add(session)
            await db.commit()
            await db.refresh(session)

            logger.info(
                f"Session {session_id} deactivated for user {session.user_id} "
                f"({session.device_name})"
            )

            return session
        except Exception as e:
            await db.rollback()
            logger.error(f"Error deactivating session {session_id}: {e!s}")
            raise

    async def deactivate_all_user_sessions(
        self, db: AsyncSession, *, user_id: str | UUID
    ) -> int:
        """Mark every active session of a user inactive with one UPDATE.

        Args:
            db: Database session used for the update.
            user_id: Owner of the sessions; a string is parsed as a UUID,
                any other value is used as-is.

        Returns:
            The ``rowcount`` reported for the UPDATE statement.

        Raises:
            ValueError: If ``user_id`` is a string that is not a valid UUID.
            Exception: Any error is logged and re-raised after a rollback.
        """
        try:
            user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id

            stmt = (
                update(UserSession)
                .where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
                .values(is_active=False)
            )

            result = await db.execute(stmt)
            await db.commit()

            count = result.rowcount

            logger.info(f"Deactivated {count} sessions for user {user_id}")

            return count
        except Exception as e:
            await db.rollback()
            logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
            raise

    async def update_last_used(
        self, db: AsyncSession, *, session: UserSession
    ) -> UserSession:
        """Set ``last_used_at`` on ``session`` to the current UTC time.

        Args:
            db: Database session used for the update.
            session: The session object to touch.

        Returns:
            The same session object, refreshed after commit.

        Raises:
            Exception: Any error is logged and re-raised after a rollback.
        """
        try:
            session.last_used_at = datetime.now(UTC)
            db.add(session)
            await db.commit()
            await db.refresh(session)
            return session
        except Exception as e:
            await db.rollback()
            logger.error(f"Error updating last_used for session {session.id}: {e!s}")
            raise

    async def update_refresh_token(
        self,
        db: AsyncSession,
        *,
        session: UserSession,
        new_jti: str,
        new_expires_at: datetime,
    ) -> UserSession:
        """Store a new refresh token JTI and expiry on ``session``.

        ``last_used_at`` is set to the current UTC time in the same commit.

        Args:
            db: Database session used for the update.
            session: The session object to update.
            new_jti: Replacement value for ``refresh_token_jti``.
            new_expires_at: Replacement value for ``expires_at``.

        Returns:
            The same session object, refreshed after commit.

        Raises:
            Exception: Any error is logged and re-raised after a rollback.
        """
        try:
            session.refresh_token_jti = new_jti
            session.expires_at = new_expires_at
            session.last_used_at = datetime.now(UTC)
            db.add(session)
            await db.commit()
            await db.refresh(session)
            return session
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error updating refresh token for session {session.id}: {e!s}"
            )
            raise

    async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
        """Bulk-delete old sessions that are both inactive and expired.

        A row is deleted only when all three hold: ``is_active`` is false,
        ``expires_at`` is before now, and ``created_at`` is more than
        ``keep_days`` days before now.

        Args:
            db: Database session used for the delete.
            keep_days: Minimum age, in days since creation, before a session
                becomes eligible for deletion.

        Returns:
            The ``rowcount`` reported for the DELETE statement.

        Raises:
            Exception: Any error is logged and re-raised after a rollback.
        """
        try:
            # Read the clock once so both comparisons share the same instant.
            now = datetime.now(UTC)
            cutoff_date = now - timedelta(days=keep_days)

            stmt = delete(UserSession).where(
                and_(
                    UserSession.is_active == False,  # noqa: E712
                    UserSession.expires_at < now,
                    UserSession.created_at < cutoff_date,
                )
            )

            result = await db.execute(stmt)
            await db.commit()

            count = result.rowcount

            if count > 0:
                logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")

            return count
        except Exception as e:
            await db.rollback()
            logger.error(f"Error cleaning up expired sessions: {e!s}")
            raise

    async def cleanup_expired_for_user(
        self, db: AsyncSession, *, user_id: str | UUID
    ) -> int:
        """Bulk-delete one user's sessions that are inactive and expired.

        Args:
            db: Database session used for the delete.
            user_id: Owner of the sessions, as a UUID string or a ``UUID``.

        Returns:
            The ``rowcount`` reported for the DELETE statement.

        Raises:
            InvalidInputError: If ``user_id`` cannot be interpreted as a UUID.
                Raised before any database work, so no rollback is issued.
            Exception: Any database error is logged and re-raised after a
                rollback.
        """
        # Validate before entering the transactional try-block: a malformed ID
        # has touched nothing in the database, so it must not trigger a
        # rollback or be reported a second time as a cleanup failure.
        if isinstance(user_id, UUID):
            uuid_obj = user_id
        else:
            try:
                uuid_obj = UUID(user_id)
            except (ValueError, AttributeError, TypeError) as exc:
                logger.error(f"Invalid UUID format: {user_id}")
                raise InvalidInputError(
                    f"Invalid user ID format: {user_id}"
                ) from exc

        try:
            now = datetime.now(UTC)

            stmt = delete(UserSession).where(
                and_(
                    UserSession.user_id == uuid_obj,
                    UserSession.is_active == False,  # noqa: E712
                    UserSession.expires_at < now,
                )
            )

            result = await db.execute(stmt)
            await db.commit()

            count = result.rowcount

            if count > 0:
                logger.info(
                    f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
                )

            return count
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error cleaning up expired sessions for user {user_id}: {e!s}"
            )
            raise

    async def get_user_session_count(
        self, db: AsyncSession, *, user_id: str | UUID
    ) -> int:
        """Count the active sessions belonging to a user.

        Args:
            db: Database session used to run the query.
            user_id: Owner of the sessions; a string is parsed as a UUID,
                any other value is used as-is.

        Returns:
            The number of rows with matching ``user_id`` and ``is_active``.

        Raises:
            ValueError: If ``user_id`` is a string that is not a valid UUID.
            Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id

            result = await db.execute(
                select(func.count(UserSession.id)).where(
                    and_(UserSession.user_id == user_uuid, UserSession.is_active)
                )
            )
            return result.scalar_one()
        except Exception as e:
            logger.error(f"Error counting sessions for user {user_id}: {e!s}")
            raise

    async def get_all_sessions(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        active_only: bool = True,
        with_user: bool = True,
    ) -> tuple[list[UserSession], int]:
        """Return one page of sessions across all users plus the total count.

        Args:
            db: Database session used to run the queries.
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            active_only: When true, both the page and the total only consider
                active sessions.
            with_user: When true, eager-load ``UserSession.user`` via a join.

        Returns:
            ``(sessions, total)`` where ``sessions`` is the requested page
            ordered by ``last_used_at`` descending and ``total`` is the count
            of all rows matching the same ``active_only`` filter.

        Raises:
            Exception: Any error is logged with traceback and re-raised.
        """
        try:
            query = select(UserSession)

            if with_user:
                query = query.options(joinedload(UserSession.user))

            if active_only:
                query = query.where(UserSession.is_active)

            count_query = select(func.count(UserSession.id))
            if active_only:
                count_query = count_query.where(UserSession.is_active)

            count_result = await db.execute(count_query)
            total = count_result.scalar_one()

            # last_used_at is not unique; the id tie-breaker gives rows a
            # total order so OFFSET/LIMIT pages neither repeat nor skip rows
            # that share a timestamp.
            query = (
                query.order_by(UserSession.last_used_at.desc(), UserSession.id)
                .offset(skip)
                .limit(limit)
            )

            result = await db.execute(query)
            sessions = list(result.scalars().all())

            return sessions, total

        except Exception as e:
            logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
            raise
|
|
|
|
|
|
# Singleton instance
|
|
# Shared module-level SessionRepository, constructed with the UserSession model.
session_repo = SessionRepository(UserSession)
|