refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError, IntegrityConstraintError, InvalidInputError) replacing raw ValueError - eliminate all direct repository imports and raw SQL from route layer - add UserService, SessionService, OrganizationService to service layer - add get_stats/get_org_distribution service methods replacing admin inline SQL - fix timing side-channel in authenticate_user via dummy bcrypt check - replace SHA-256 client secret fallback with explicit InvalidClientError - replace assert with InvalidGrantError in authorization code exchange - replace N+1 token revocation loops with bulk UPDATE statements - rename oauth account token fields (drop misleading 'encrypted' suffix) - add Alembic migration 0003 for token field column rename - add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
39
backend/app/repositories/__init__.py
Normal file
39
backend/app/repositories/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# app/repositories/__init__.py
|
||||
"""Repository layer — all database access goes through these classes."""
|
||||
|
||||
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
|
||||
from app.repositories.oauth_authorization_code import (
|
||||
OAuthAuthorizationCodeRepository,
|
||||
oauth_authorization_code_repo,
|
||||
)
|
||||
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
|
||||
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
|
||||
from app.repositories.oauth_provider_token import (
|
||||
OAuthProviderTokenRepository,
|
||||
oauth_provider_token_repo,
|
||||
)
|
||||
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
|
||||
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||
from app.repositories.session import SessionRepository, session_repo
|
||||
from app.repositories.user import UserRepository, user_repo
|
||||
|
||||
__all__ = [
|
||||
"UserRepository",
|
||||
"user_repo",
|
||||
"OrganizationRepository",
|
||||
"organization_repo",
|
||||
"SessionRepository",
|
||||
"session_repo",
|
||||
"OAuthAccountRepository",
|
||||
"oauth_account_repo",
|
||||
"OAuthAuthorizationCodeRepository",
|
||||
"oauth_authorization_code_repo",
|
||||
"OAuthClientRepository",
|
||||
"oauth_client_repo",
|
||||
"OAuthConsentRepository",
|
||||
"oauth_consent_repo",
|
||||
"OAuthProviderTokenRepository",
|
||||
"oauth_provider_token_repo",
|
||||
"OAuthStateRepository",
|
||||
"oauth_state_repo",
|
||||
]
|
||||
414
backend/app/repositories/base.py
Normal file
414
backend/app/repositories/base.py
Normal file
@@ -0,0 +1,414 @@
|
||||
# app/repositories/base.py
|
||||
"""
|
||||
Base repository class for async CRUD operations using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
from app.core.database import Base
|
||||
from app.core.repository_exceptions import (
|
||||
DuplicateEntryError,
|
||||
IntegrityConstraintError,
|
||||
InvalidInputError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseRepository[
|
||||
ModelType: Base,
|
||||
CreateSchemaType: BaseModel,
|
||||
UpdateSchemaType: BaseModel,
|
||||
]:
|
||||
"""Async repository operations for a model."""
|
||||
|
||||
def __init__(self, model: type[ModelType]):
|
||||
"""
|
||||
Repository object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
async def get(
|
||||
self, db: AsyncSession, id: str, options: list[Load] | None = None
|
||||
) -> ModelType | None:
|
||||
"""
|
||||
Get a single record by ID with UUID validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
id: Record UUID
|
||||
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
|
||||
for eager loading relationships to prevent N+1 queries
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
"""
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
query = select(self.model).where(self.model.id == uuid_obj)
|
||||
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
options: list[Load] | None = None,
|
||||
) -> list[ModelType]:
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
"""
|
||||
if skip < 0:
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
|
||||
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: CreateSchemaType
|
||||
) -> ModelType: # pragma: no cover
|
||||
"""Create a new record with error handling.
|
||||
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
All repository subclasses override this method with their own implementations.
|
||||
Marked as pragma: no cover to avoid false coverage gaps.
|
||||
"""
|
||||
try: # pragma: no cover
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(
|
||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
|
||||
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: UpdateSchemaType | dict[str, Any],
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(
|
||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
|
||||
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""Delete a record with error handling and null check."""
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(
|
||||
f"{self.model.__name__} with id {id} not found for deletion"
|
||||
)
|
||||
return None
|
||||
|
||||
await db.delete(obj)
|
||||
await db.commit()
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise IntegrityConstraintError(
|
||||
f"Cannot delete {self.model.__name__}: referenced by other records"
|
||||
)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> tuple[list[ModelType], int]: # pragma: no cover
|
||||
"""
|
||||
Get multiple records with total count, filtering, and sorting.
|
||||
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
All repository subclasses override this method with their own implementations.
|
||||
Marked as pragma: no cover to avoid false coverage gaps.
|
||||
"""
|
||||
if skip < 0:
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
query = select(self.model)
|
||||
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
query = query.where(self.model.deleted_at.is_(None))
|
||||
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field) and value is not None:
|
||||
query = query.where(getattr(self.model, field) == value)
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
if sort_by and hasattr(self.model, sort_by):
|
||||
sort_column = getattr(self.model, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
else:
|
||||
query = query.order_by(self.model.id)
|
||||
|
||||
query = query.offset(skip).limit(limit)
|
||||
items_result = await db.execute(query)
|
||||
items = list(items_result.scalars().all())
|
||||
|
||||
return items, total
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
"""Get total count of records."""
|
||||
try:
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting {self.model.__name__} records: {e!s}")
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
"""Check if a record exists by ID."""
|
||||
obj = await self.get(db, id=id)
|
||||
return obj is not None
|
||||
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Soft delete a record by setting deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(
|
||||
f"{self.model.__name__} with id {id} not found for soft deletion"
|
||||
)
|
||||
return None
|
||||
|
||||
if not hasattr(self.model, "deleted_at"):
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise InvalidInputError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
obj.deleted_at = datetime.now(UTC)
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
result = await db.execute(
|
||||
select(self.model).where(
|
||||
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
|
||||
)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
else:
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise InvalidInputError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
if obj is None:
|
||||
logger.warning(
|
||||
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
|
||||
)
|
||||
return None
|
||||
|
||||
obj.deleted_at = None
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
235
backend/app/repositories/oauth_account.py
Normal file
235
backend/app/repositories/oauth_account.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# app/repositories/oauth_account.py
|
||||
"""Repository for OAuthAccount model async CRUD operations."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_account import OAuthAccount
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthAccountCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthAccountRepository(BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]):
|
||||
"""Repository for OAuth account links."""
|
||||
|
||||
async def get_by_provider_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get OAuth account by provider and provider user ID."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_user_id == provider_user_id,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_provider_email(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
email: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get OAuth account by provider and email."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_email == email,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
f"Error getting OAuth account for {provider} email {email}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_accounts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
) -> list[OAuthAccount]:
|
||||
"""Get all OAuth accounts linked to a user."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(OAuthAccount.user_id == user_uuid)
|
||||
.order_by(OAuthAccount.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_account_by_provider(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get a specific OAuth account for a user and provider."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_account(
|
||||
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
|
||||
) -> OAuthAccount:
|
||||
"""Create a new OAuth account link."""
|
||||
try:
|
||||
db_obj = OAuthAccount(
|
||||
user_id=obj_in.user_id,
|
||||
provider=obj_in.provider,
|
||||
provider_user_id=obj_in.provider_user_id,
|
||||
provider_email=obj_in.provider_email,
|
||||
access_token=obj_in.access_token,
|
||||
refresh_token=obj_in.refresh_token,
|
||||
token_expires_at=obj_in.token_expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
|
||||
)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "uq_oauth_provider_user" in error_msg.lower():
|
||||
logger.warning(
|
||||
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"This {obj_in.provider} account is already linked to another user"
|
||||
)
|
||||
logger.error(f"Integrity error creating OAuth account: {error_msg}")
|
||||
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def delete_account(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""Delete an OAuth account link."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
delete(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"OAuth account deleted: {provider} unlinked from user {user_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"OAuth account not found for deletion: {provider} for user {user_id}"
|
||||
)
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_tokens(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account: OAuthAccount,
|
||||
access_token: str | None = None,
|
||||
refresh_token: str | None = None,
|
||||
token_expires_at: datetime | None = None,
|
||||
) -> OAuthAccount:
|
||||
"""Update OAuth tokens for an account."""
|
||||
try:
|
||||
if access_token is not None:
|
||||
account.access_token = access_token
|
||||
if refresh_token is not None:
|
||||
account.refresh_token = refresh_token
|
||||
if token_expires_at is not None:
|
||||
account.token_expires_at = token_expires_at
|
||||
|
||||
db.add(account)
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
|
||||
return account
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating OAuth tokens: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_account_repo = OAuthAccountRepository(OAuthAccount)
|
||||
108
backend/app/repositories/oauth_authorization_code.py
Normal file
108
backend/app/repositories/oauth_authorization_code.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# app/repositories/oauth_authorization_code.py
|
||||
"""Repository for OAuthAuthorizationCode model."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthAuthorizationCodeRepository:
|
||||
"""Repository for OAuth 2.0 authorization codes."""
|
||||
|
||||
async def create_code(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
code: str,
|
||||
client_id: str,
|
||||
user_id: UUID,
|
||||
redirect_uri: str,
|
||||
scope: str,
|
||||
expires_at: datetime,
|
||||
code_challenge: str | None = None,
|
||||
code_challenge_method: str | None = None,
|
||||
state: str | None = None,
|
||||
nonce: str | None = None,
|
||||
) -> OAuthAuthorizationCode:
|
||||
"""Create and persist a new authorization code."""
|
||||
auth_code = OAuthAuthorizationCode(
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
expires_at=expires_at,
|
||||
used=False,
|
||||
)
|
||||
db.add(auth_code)
|
||||
await db.commit()
|
||||
return auth_code
|
||||
|
||||
async def consume_code_atomically(
|
||||
self, db: AsyncSession, *, code: str
|
||||
) -> UUID | None:
|
||||
"""
|
||||
Atomically mark a code as used and return its UUID.
|
||||
|
||||
Returns the UUID if the code was found and not yet used, None otherwise.
|
||||
This prevents race conditions per RFC 6749 Section 4.1.2.
|
||||
"""
|
||||
stmt = (
|
||||
update(OAuthAuthorizationCode)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAuthorizationCode.code == code,
|
||||
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(used=True)
|
||||
.returning(OAuthAuthorizationCode.id)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
row_id = result.scalar_one_or_none()
|
||||
if row_id is not None:
|
||||
await db.commit()
|
||||
return row_id
|
||||
|
||||
async def get_by_id(
|
||||
self, db: AsyncSession, *, code_id: UUID
|
||||
) -> OAuthAuthorizationCode | None:
|
||||
"""Get authorization code by its UUID primary key."""
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_code(
|
||||
self, db: AsyncSession, *, code: str
|
||||
) -> OAuthAuthorizationCode | None:
|
||||
"""Get authorization code by the code string value."""
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""Delete all expired authorization codes. Returns count deleted."""
|
||||
result = await db.execute(
|
||||
delete(OAuthAuthorizationCode).where(
|
||||
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()
|
||||
199
backend/app/repositories/oauth_client.py
Normal file
199
backend/app/repositories/oauth_client.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# app/repositories/oauth_client.py
|
||||
"""Repository for OAuthClient model async CRUD operations."""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthClientRepository(BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]):
|
||||
"""Repository for OAuth clients (provider mode)."""
|
||||
|
||||
async def get_by_client_id(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""Get OAuth client by client_id."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: OAuthClientCreate,
|
||||
owner_user_id: UUID | None = None,
|
||||
) -> tuple[OAuthClient, str | None]:
|
||||
"""Create a new OAuth client."""
|
||||
try:
|
||||
client_id = secrets.token_urlsafe(32)
|
||||
|
||||
client_secret = None
|
||||
client_secret_hash = None
|
||||
if obj_in.client_type == "confidential":
|
||||
client_secret = secrets.token_urlsafe(48)
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
client_secret_hash = get_password_hash(client_secret)
|
||||
|
||||
db_obj = OAuthClient(
|
||||
client_id=client_id,
|
||||
client_secret_hash=client_secret_hash,
|
||||
client_name=obj_in.client_name,
|
||||
client_description=obj_in.client_description,
|
||||
client_type=obj_in.client_type,
|
||||
redirect_uris=obj_in.redirect_uris,
|
||||
allowed_scopes=obj_in.allowed_scopes,
|
||||
owner_user_id=owner_user_id,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
|
||||
)
|
||||
return db_obj, client_secret
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Error creating OAuth client: {error_msg}")
|
||||
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def deactivate_client(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""Deactivate an OAuth client."""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
client.is_active = False
|
||||
db.add(client)
|
||||
await db.commit()
|
||||
await db.refresh(client)
|
||||
|
||||
logger.info(f"OAuth client deactivated: {client.client_name}")
|
||||
return client
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def validate_redirect_uri(
|
||||
self, db: AsyncSession, *, client_id: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""Validate that a redirect URI is allowed for a client."""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return False
|
||||
|
||||
return redirect_uri in (client.redirect_uris or [])
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error validating redirect URI: {e!s}")
|
||||
return False
|
||||
|
||||
async def verify_client_secret(
|
||||
self, db: AsyncSession, *, client_id: str, client_secret: str
|
||||
) -> bool:
|
||||
"""Verify client credentials."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
client = result.scalar_one_or_none()
|
||||
|
||||
if client is None or client.client_secret_hash is None:
|
||||
return False
|
||||
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash: str = str(client.client_secret_hash)
|
||||
|
||||
if stored_hash.startswith("$2"):
|
||||
return verify_password(client_secret, stored_hash)
|
||||
else:
|
||||
import hashlib
|
||||
|
||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
return secrets.compare_digest(stored_hash, secret_hash)
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error verifying client secret: {e!s}")
|
||||
return False
|
||||
|
||||
async def get_all_clients(
|
||||
self, db: AsyncSession, *, include_inactive: bool = False
|
||||
) -> list[OAuthClient]:
|
||||
"""Get all OAuth clients."""
|
||||
try:
|
||||
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
||||
if not include_inactive:
|
||||
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error getting all OAuth clients: {e!s}")
|
||||
raise
|
||||
|
||||
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
||||
"""Delete an OAuth client permanently."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(f"OAuth client deleted: {client_id}")
|
||||
else:
|
||||
logger.warning(f"OAuth client not found for deletion: {client_id}")
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_client_repo = OAuthClientRepository(OAuthClient)
|
||||
112
backend/app/repositories/oauth_consent.py
Normal file
112
backend/app/repositories/oauth_consent.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# app/repositories/oauth_consent.py
|
||||
"""Repository for OAuthConsent model."""
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_provider_token import OAuthConsent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthConsentRepository:
|
||||
"""Repository for OAuth consent records (user grants to clients)."""
|
||||
|
||||
async def get_consent(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> OAuthConsent | None:
|
||||
"""Get the consent record for a user-client pair, or None if not found."""
|
||||
result = await db.execute(
|
||||
select(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def grant_consent(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
) -> OAuthConsent:
|
||||
"""
|
||||
Create or update consent for a user-client pair.
|
||||
|
||||
If consent already exists, the new scopes are merged with existing ones.
|
||||
Returns the created or updated consent record.
|
||||
"""
|
||||
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
|
||||
|
||||
if consent:
|
||||
existing = set(consent.granted_scopes.split()) if consent.granted_scopes else set()
|
||||
merged = existing | set(scopes)
|
||||
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
|
||||
else:
|
||||
consent = OAuthConsent(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
granted_scopes=" ".join(sorted(set(scopes))),
|
||||
)
|
||||
db.add(consent)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(consent)
|
||||
return consent
|
||||
|
||||
async def get_user_consents_with_clients(
|
||||
self, db: AsyncSession, *, user_id: UUID
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all consent records for a user joined with client details."""
|
||||
result = await db.execute(
|
||||
select(OAuthConsent, OAuthClient)
|
||||
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
||||
.where(OAuthConsent.user_id == user_id)
|
||||
)
|
||||
rows = result.all()
|
||||
return [
|
||||
{
|
||||
"client_id": consent.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_description": client.client_description,
|
||||
"granted_scopes": consent.granted_scopes.split()
|
||||
if consent.granted_scopes
|
||||
else [],
|
||||
"granted_at": consent.created_at.isoformat(),
|
||||
}
|
||||
for consent, client in rows
|
||||
]
|
||||
|
||||
async def revoke_consent(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete the consent record for a user-client pair.
|
||||
|
||||
Returns True if a record was found and deleted.
|
||||
Note: Callers are responsible for also revoking associated tokens.
|
||||
"""
|
||||
result = await db.execute(
|
||||
delete(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_consent_repo = OAuthConsentRepository()
|
||||
146
backend/app/repositories/oauth_provider_token.py
Normal file
146
backend/app/repositories/oauth_provider_token.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# app/repositories/oauth_provider_token.py
|
||||
"""Repository for OAuthProviderRefreshToken model."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_provider_token import OAuthProviderRefreshToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthProviderTokenRepository:
|
||||
"""Repository for OAuth provider refresh tokens."""
|
||||
|
||||
async def create_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
token_hash: str,
|
||||
jti: str,
|
||||
client_id: str,
|
||||
user_id: UUID,
|
||||
scope: str,
|
||||
expires_at: datetime,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> OAuthProviderRefreshToken:
|
||||
"""Create and persist a new refresh token record."""
|
||||
token = OAuthProviderRefreshToken(
|
||||
token_hash=token_hash,
|
||||
jti=jti,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
db.add(token)
|
||||
await db.commit()
|
||||
return token
|
||||
|
||||
async def get_by_token_hash(
|
||||
self, db: AsyncSession, *, token_hash: str
|
||||
) -> OAuthProviderRefreshToken | None:
|
||||
"""Get refresh token record by SHA-256 token hash."""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> OAuthProviderRefreshToken | None:
|
||||
"""Get refresh token record by JWT ID (JTI)."""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.jti == jti
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def revoke(
|
||||
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
|
||||
) -> None:
|
||||
"""Mark a specific token record as revoked."""
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||
await db.commit()
|
||||
|
||||
async def revoke_all_for_user_client(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all active tokens for a specific user-client pair.
|
||||
|
||||
Used when security incidents are detected (e.g., authorization code reuse).
|
||||
Returns the number of tokens revoked.
|
||||
"""
|
||||
result = await db.execute(
|
||||
update(OAuthProviderRefreshToken)
|
||||
.where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.client_id == client_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(revoked=True)
|
||||
)
|
||||
count = result.rowcount # type: ignore[attr-defined]
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
return count
|
||||
|
||||
async def revoke_all_for_user(
|
||||
self, db: AsyncSession, *, user_id: UUID
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all active tokens for a user across all clients.
|
||||
|
||||
Used when user changes password or logs out everywhere.
|
||||
Returns the number of tokens revoked.
|
||||
"""
|
||||
result = await db.execute(
|
||||
update(OAuthProviderRefreshToken)
|
||||
.where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(revoked=True)
|
||||
)
|
||||
count = result.rowcount # type: ignore[attr-defined]
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
return count
|
||||
|
||||
async def cleanup_expired(
|
||||
self, db: AsyncSession, *, cutoff_days: int = 7
|
||||
) -> int:
|
||||
"""
|
||||
Delete expired refresh tokens older than cutoff_days.
|
||||
|
||||
Should be called periodically (e.g., daily).
|
||||
Returns the number of tokens deleted.
|
||||
"""
|
||||
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
|
||||
result = await db.execute(
|
||||
delete(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.expires_at < cutoff
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_provider_token_repo = OAuthProviderTokenRepository()
|
||||
113
backend/app/repositories/oauth_state.py
Normal file
113
backend/app/repositories/oauth_state.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# app/repositories/oauth_state.py
|
||||
"""Repository for OAuthState model async CRUD operations."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_state import OAuthState
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthStateCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
|
||||
"""Repository for OAuth state (CSRF protection)."""
|
||||
|
||||
async def create_state(
|
||||
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
||||
) -> OAuthState:
|
||||
"""Create a new OAuth state for CSRF protection."""
|
||||
try:
|
||||
db_obj = OAuthState(
|
||||
state=obj_in.state,
|
||||
code_verifier=obj_in.code_verifier,
|
||||
nonce=obj_in.nonce,
|
||||
provider=obj_in.provider,
|
||||
redirect_uri=obj_in.redirect_uri,
|
||||
user_id=obj_in.user_id,
|
||||
expires_at=obj_in.expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.debug(f"OAuth state created for {obj_in.provider}")
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"OAuth state collision: {error_msg}")
|
||||
raise DuplicateEntryError("Failed to create OAuth state, please retry")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_and_consume_state(
|
||||
self, db: AsyncSession, *, state: str
|
||||
) -> OAuthState | None:
|
||||
"""Get and delete OAuth state (consume it)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthState).where(OAuthState.state == state)
|
||||
)
|
||||
db_obj = result.scalar_one_or_none()
|
||||
|
||||
if db_obj is None:
|
||||
logger.warning(f"OAuth state not found: {state[:8]}...")
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
expires_at = db_obj.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
logger.warning(f"OAuth state expired: {state[:8]}...")
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
|
||||
logger.debug(f"OAuth state consumed: {state[:8]}...")
|
||||
return db_obj
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error consuming OAuth state: {e!s}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""Clean up expired OAuth states."""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired OAuth states")
|
||||
|
||||
return count
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_state_repo = OAuthStateRepository(OAuthState)
|
||||
496
backend/app/repositories/organization.py
Normal file
496
backend/app/repositories/organization.py
Normal file
@@ -0,0 +1,496 @@
|
||||
# app/repositories/organization.py
|
||||
"""Repository for Organization model async CRUD operations using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, case, func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.organizations import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||
"""Repository for Organization model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
||||
"""Get organization by slug."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.slug == slug)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||
) -> Organization:
|
||||
"""Create a new organization with error handling."""
|
||||
try:
|
||||
db_obj = Organization(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
is_active=obj_in.is_active,
|
||||
settings=obj_in.settings or {},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "slug" in error_msg.lower() or "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise DuplicateEntryError(
|
||||
f"Organization with slug '{obj_in.slug}' already exists"
|
||||
)
|
||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Unexpected error creating organization: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Organization], int]:
|
||||
"""Get multiple organizations with filtering, searching, and sorting."""
|
||||
try:
|
||||
query = select(Organization)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
organizations = list(result.scalars().all())
|
||||
|
||||
return organizations, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
"""Get the count of active members in an organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(func.count(UserOrganization.user_id)).where(
|
||||
and_(
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one() or 0
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting member count for organization {organization_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_member_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
|
||||
try:
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
func.count(
|
||||
func.distinct(
|
||||
case(
|
||||
(
|
||||
UserOrganization.is_active,
|
||||
UserOrganization.user_id,
|
||||
),
|
||||
else_=None,
|
||||
)
|
||||
)
|
||||
).label("member_count"),
|
||||
)
|
||||
.outerjoin(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.group_by(Organization.id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
count_query = select(func.count(Organization.id))
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Organization.is_active == is_active)
|
||||
if search:
|
||||
count_query = count_query.where(search_filter)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
query = (
|
||||
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
orgs_with_counts = [
|
||||
{"organization": org, "member_count": member_count}
|
||||
for org, member_count in rows
|
||||
]
|
||||
|
||||
return orgs_with_counts, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting organizations with member counts: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def add_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
if not existing.is_active:
|
||||
existing.is_active = True
|
||||
existing.role = role
|
||||
existing.custom_permissions = custom_permissions
|
||||
await db.commit()
|
||||
await db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
raise DuplicateEntryError("User is already a member of this organization")
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
role=role,
|
||||
is_active=True,
|
||||
custom_permissions=custom_permissions,
|
||||
)
|
||||
db.add(user_org)
|
||||
await db.commit()
|
||||
await db.refresh(user_org)
|
||||
return user_org
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Integrity error adding user to organization: {e!s}")
|
||||
raise IntegrityConstraintError("Failed to add user to organization")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove_user(
|
||||
self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
|
||||
) -> bool:
|
||||
"""Remove a user from an organization (soft delete)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
if not user_org:
|
||||
return False
|
||||
|
||||
user_org.is_active = False
|
||||
await db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_user_role(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole,
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization | None:
|
||||
"""Update a user's role in an organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
if not user_org:
|
||||
return None
|
||||
|
||||
user_org.role = role
|
||||
if custom_permissions is not None:
|
||||
user_org.custom_permissions = custom_permissions
|
||||
await db.commit()
|
||||
await db.refresh(user_org)
|
||||
return user_org
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating user role: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_organization_members(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get members of an organization with user details."""
|
||||
try:
|
||||
query = (
|
||||
select(UserOrganization, User)
|
||||
.join(User, UserOrganization.user_id == User.id)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
count_query = select(func.count()).select_from(
|
||||
select(UserOrganization)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
.where(
|
||||
UserOrganization.is_active == is_active
|
||||
if is_active is not None
|
||||
else True
|
||||
)
|
||||
.alias()
|
||||
)
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
query = (
|
||||
query.order_by(UserOrganization.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
results = result.all()
|
||||
|
||||
members = []
|
||||
for user_org, user in results:
|
||||
members.append(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at,
|
||||
}
|
||||
)
|
||||
|
||||
return members, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations(
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
) -> list[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
try:
|
||||
query = (
|
||||
select(Organization)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get user's organizations with role and member count in SINGLE QUERY."""
|
||||
try:
|
||||
member_count_subq = (
|
||||
select(
|
||||
UserOrganization.organization_id,
|
||||
func.count(UserOrganization.user_id).label("member_count"),
|
||||
)
|
||||
.where(UserOrganization.is_active)
|
||||
.group_by(UserOrganization.organization_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
UserOrganization.role,
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label(
|
||||
"member_count"
|
||||
),
|
||||
)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.outerjoin(
|
||||
member_count_subq,
|
||||
Organization.id == member_count_subq.c.organization_id,
|
||||
)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{"organization": org, "role": role, "member_count": member_count}
|
||||
for org, role, member_count in rows
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting user organizations with details: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> OrganizationRole | None:
|
||||
"""Get a user's role in a specific organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
return user_org.role if user_org else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user role in org: {e!s}")
|
||||
raise
|
||||
|
||||
async def is_user_org_owner(
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner of an organization."""
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role == OrganizationRole.OWNER
|
||||
|
||||
async def is_user_org_admin(
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner or admin of an organization."""
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
organization_repo = OrganizationRepository(Organization)
|
||||
327
backend/app/repositories/session.py
Normal file
327
backend/app/repositories/session.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# app/repositories/session.py
|
||||
"""Repository for UserSession model async CRUD operations using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.repository_exceptions import InvalidInputError, IntegrityConstraintError
|
||||
from app.models.user_session import UserSession
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Repository for UserSession model."""
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""Get session by refresh token JTI."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI {jti}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_active_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> UserSession | None:
|
||||
"""Get active session by refresh token JTI."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active session by JTI {jti}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True,
|
||||
with_user: bool = False,
|
||||
) -> list[UserSession]:
|
||||
"""Get all sessions for a user with optional eager loading."""
|
||||
try:
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
query = query.order_by(UserSession.last_used_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create_session(
|
||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""Create a new user session."""
|
||||
try:
|
||||
db_obj = UserSession(
|
||||
user_id=obj_in.user_id,
|
||||
refresh_token_jti=obj_in.refresh_token_jti,
|
||||
device_name=obj_in.device_name,
|
||||
device_id=obj_in.device_id,
|
||||
ip_address=obj_in.ip_address,
|
||||
user_agent=obj_in.user_agent,
|
||||
last_used_at=obj_in.last_used_at,
|
||||
expires_at=obj_in.expires_at,
|
||||
is_active=True,
|
||||
location_city=obj_in.location_city,
|
||||
location_country=obj_in.location_country,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
||||
f"(IP: {obj_in.ip_address})"
|
||||
)
|
||||
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating session: {e!s}", exc_info=True)
|
||||
raise IntegrityConstraintError(f"Failed to create session: {e!s}")
|
||||
|
||||
async def deactivate(
|
||||
self, db: AsyncSession, *, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Deactivate a session (logout from device)."""
|
||||
try:
|
||||
session = await self.get(db, id=session_id)
|
||||
if not session:
|
||||
logger.warning(f"Session {session_id} not found for deactivation")
|
||||
return None
|
||||
|
||||
session.is_active = False
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
logger.info(
|
||||
f"Session {session_id} deactivated for user {session.user_id} "
|
||||
f"({session.device_name})"
|
||||
)
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating session {session_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str
|
||||
) -> int:
|
||||
"""Deactivate all active sessions for a user (logout from all devices)."""
|
||||
try:
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
stmt = (
|
||||
update(UserSession)
|
||||
.where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
|
||||
.values(is_active=False)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def update_last_used(
|
||||
self, db: AsyncSession, *, session: UserSession
|
||||
) -> UserSession:
|
||||
"""Update the last_used_at timestamp for a session."""
|
||||
try:
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating last_used for session {session.id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def update_refresh_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime,
|
||||
) -> UserSession:
|
||||
"""Update session with new refresh token JTI and expiration."""
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error updating refresh token for session {session.id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
"""Clean up expired sessions using optimized bulk DELETE."""
|
||||
try:
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False, # noqa: E712
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date,
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error cleaning up expired sessions: {e!s}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Clean up expired and inactive sessions for a specific user."""
|
||||
try:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.error(f"Invalid UUID format: {user_id}")
|
||||
raise InvalidInputError(f"Invalid user ID format: {user_id}")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.user_id == uuid_obj,
|
||||
UserSession.is_active == False, # noqa: E712
|
||||
UserSession.expires_at < now,
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(
|
||||
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
|
||||
)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error cleaning up expired sessions for user {user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Get count of active sessions for a user."""
|
||||
try:
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
and_(UserSession.user_id == user_uuid, UserSession.is_active)
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_all_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
with_user: bool = True,
|
||||
) -> tuple[list[UserSession], int]:
|
||||
"""Get all sessions across all users with pagination (admin only)."""
|
||||
try:
|
||||
query = select(UserSession)
|
||||
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
count_query = select(func.count(UserSession.id))
|
||||
if active_only:
|
||||
count_query = count_query.where(UserSession.is_active)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
query = (
|
||||
query.order_by(UserSession.last_used_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
sessions = list(result.scalars().all())
|
||||
|
||||
return sessions, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
session_repo = SessionRepository(UserSession)
|
||||
265
backend/app/repositories/user.py
Normal file
265
backend/app/repositories/user.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# app/repositories/user.py
|
||||
"""Repository for User model async CRUD operations using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_password_hash_async
|
||||
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
"""Repository for User model."""
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
||||
"""Get user by email address."""
|
||||
try:
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by email {email}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with async password hashing and error handling."""
|
||||
try:
|
||||
password_hash = await get_password_hash_async(obj_in.password)
|
||||
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=password_hash,
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number
|
||||
if hasattr(obj_in, "phone_number")
|
||||
else None,
|
||||
is_superuser=obj_in.is_superuser
|
||||
if hasattr(obj_in, "is_superuser")
|
||||
else False,
|
||||
preferences={},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
raise DuplicateEntryError(f"User with email {obj_in.email} already exists")
|
||||
logger.error(f"Integrity error creating user: {error_msg}")
|
||||
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_oauth_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
email: str,
|
||||
first_name: str = "User",
|
||||
last_name: str | None = None,
|
||||
) -> User:
|
||||
"""Create a new passwordless user for OAuth sign-in."""
|
||||
try:
|
||||
db_obj = User(
|
||||
email=email,
|
||||
password_hash=None, # OAuth-only user
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.flush() # Get user.id without committing
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {email}")
|
||||
raise DuplicateEntryError(f"User with email {email} already exists")
|
||||
logger.error(f"Integrity error creating OAuth user: {error_msg}")
|
||||
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating OAuth user: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
|
||||
) -> User:
|
||||
"""Update user with async password hashing if password is updated."""
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = await get_password_hash_async(
|
||||
update_data["password"]
|
||||
)
|
||||
del update_data["password"]
|
||||
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
async def update_password(
|
||||
self, db: AsyncSession, *, user: User, password_hash: str
|
||||
) -> User:
|
||||
"""Set a new password hash on a user and commit."""
|
||||
user.password_hash = password_hash
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: dict[str, Any] | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[User], int]:
|
||||
"""Get multiple users with total count, filtering, sorting, and search."""
|
||||
if skip < 0:
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
query = select(User)
|
||||
query = query.where(User.deleted_at.is_(None))
|
||||
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(User, field) and value is not None:
|
||||
query = query.where(getattr(User, field) == value)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
User.email.ilike(f"%{search}%"),
|
||||
User.first_name.ilike(f"%{search}%"),
|
||||
User.last_name.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
if sort_by and hasattr(User, sort_by):
|
||||
sort_column = getattr(User, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
users = list(result.scalars().all())
|
||||
|
||||
return users, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated users: {e!s}")
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||
) -> int:
|
||||
"""Bulk update is_active status for multiple users."""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None))
|
||||
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: list[UUID],
|
||||
exclude_user_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""Bulk soft delete multiple users."""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||
|
||||
if not filtered_ids:
|
||||
return 0
|
||||
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None))
|
||||
.values(
|
||||
deleted_at=datetime.now(UTC),
|
||||
is_active=False,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(f"Bulk soft deleted {deleted_count} users")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
"""Check if user is active."""
|
||||
return user.is_active
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
"""Check if user is a superuser."""
|
||||
return user.is_superuser
|
||||
|
||||
|
||||
# Singleton instance
|
||||
user_repo = UserRepository(User)
|
||||
Reference in New Issue
Block a user