refactor(backend): enforce route→service→repo layered architecture

- introduce custom repository exception hierarchy (DuplicateEntryError,
  IntegrityConstraintError, InvalidInputError) replacing raw ValueError
- eliminate all direct repository imports and raw SQL from route layer
- add UserService, SessionService, OrganizationService to service layer
- add get_stats/get_org_distribution service methods replacing admin inline SQL
- fix timing side-channel in authenticate_user via dummy bcrypt check
- replace SHA-256 client secret fallback with explicit InvalidClientError
- replace assert with InvalidGrantError in authorization code exchange
- replace N+1 token revocation loops with bulk UPDATE statements
- rename oauth account token fields (drop misleading 'encrypted' suffix)
- add Alembic migration 0003 for token field column rename
- add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
2026-02-27 09:32:57 +01:00
parent 0646c96b19
commit 98b455fdc3
62 changed files with 2933 additions and 1728 deletions

View File

@@ -0,0 +1,39 @@
# app/repositories/__init__.py
"""Repository layer — all database access goes through these classes."""
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
from app.repositories.oauth_authorization_code import (
OAuthAuthorizationCodeRepository,
oauth_authorization_code_repo,
)
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
from app.repositories.oauth_provider_token import (
OAuthProviderTokenRepository,
oauth_provider_token_repo,
)
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
from app.repositories.organization import OrganizationRepository, organization_repo
from app.repositories.session import SessionRepository, session_repo
from app.repositories.user import UserRepository, user_repo
__all__ = [
"UserRepository",
"user_repo",
"OrganizationRepository",
"organization_repo",
"SessionRepository",
"session_repo",
"OAuthAccountRepository",
"oauth_account_repo",
"OAuthAuthorizationCodeRepository",
"oauth_authorization_code_repo",
"OAuthClientRepository",
"oauth_client_repo",
"OAuthConsentRepository",
"oauth_consent_repo",
"OAuthProviderTokenRepository",
"oauth_provider_token_repo",
"OAuthStateRepository",
"oauth_state_repo",
]

View File

@@ -0,0 +1,414 @@
# app/repositories/base.py
"""
Base repository class for async CRUD operations using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
import logging
import uuid
from datetime import UTC
from typing import Any, TypeVar
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
from app.core.database import Base
from app.core.repository_exceptions import (
DuplicateEntryError,
IntegrityConstraintError,
InvalidInputError,
)
logger = logging.getLogger(__name__)
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class BaseRepository[
ModelType: Base,
CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel,
]:
"""Async repository operations for a model."""
def __init__(self, model: type[ModelType]):
"""
Repository object with default async methods to Create, Read, Update, Delete.
Parameters:
model: A SQLAlchemy model class
"""
self.model = model
async def get(
self, db: AsyncSession, id: str, options: list[Load] | None = None
) -> ModelType | None:
"""
Get a single record by ID with UUID validation and optional eager loading.
Args:
db: Database session
id: Record UUID
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
for eager loading relationships to prevent N+1 queries
Returns:
Model instance or None if not found
"""
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {e!s}")
return None
try:
query = select(self.model).where(self.model.id == uuid_obj)
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
raise
async def get_multi(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
options: list[Load] | None = None,
) -> list[ModelType]:
"""
Get multiple records with pagination validation and optional eager loading.
"""
if skip < 0:
raise InvalidInputError("skip must be non-negative")
if limit < 0:
raise InvalidInputError("limit must be non-negative")
if limit > 1000:
raise InvalidInputError("Maximum limit is 1000")
try:
query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
if options:
for option in options:
query = query.options(option)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
)
raise
async def create(
self, db: AsyncSession, *, obj_in: CreateSchemaType
) -> ModelType: # pragma: no cover
"""Create a new record with error handling.
NOTE: This method is defensive code that's never called in practice.
All repository subclasses override this method with their own implementations.
Marked as pragma: no cover to avoid false coverage gaps.
"""
try: # pragma: no cover
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: # pragma: no cover
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: UpdateSchemaType | dict[str, Any],
) -> ModelType:
"""Update a record with error handling."""
try:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e:
await db.rollback()
logger.error(
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
)
raise
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""Delete a record with error handling and null check."""
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(
f"{self.model.__name__} with id {id} not found for deletion"
)
return None
await db.delete(obj)
await db.commit()
return obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise IntegrityConstraintError(
f"Cannot delete {self.model.__name__}: referenced by other records"
)
except Exception as e:
await db.rollback()
logger.error(
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise
async def get_multi_with_total(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: str | None = None,
sort_order: str = "asc",
filters: dict[str, Any] | None = None,
) -> tuple[list[ModelType], int]: # pragma: no cover
"""
Get multiple records with total count, filtering, and sorting.
NOTE: This method is defensive code that's never called in practice.
All repository subclasses override this method with their own implementations.
Marked as pragma: no cover to avoid false coverage gaps.
"""
if skip < 0:
raise InvalidInputError("skip must be non-negative")
if limit < 0:
raise InvalidInputError("limit must be non-negative")
if limit > 1000:
raise InvalidInputError("Maximum limit is 1000")
try:
query = select(self.model)
if hasattr(self.model, "deleted_at"):
query = query.where(self.model.deleted_at.is_(None))
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.where(getattr(self.model, field) == value)
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
else:
query = query.order_by(self.model.id)
query = query.offset(skip).limit(limit)
items_result = await db.execute(query)
items = list(items_result.scalars().all())
return items, total
except Exception as e: # pragma: no cover
logger.error(
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
)
raise
async def count(self, db: AsyncSession) -> int:
"""Get total count of records."""
try:
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {e!s}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
"""Check if a record exists by ID."""
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
return None
try:
result = await db.execute(
select(self.model).where(self.model.id == uuid_obj)
)
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(
f"{self.model.__name__} with id {id} not found for soft deletion"
)
return None
if not hasattr(self.model, "deleted_at"):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column"
)
obj.deleted_at = datetime.now(UTC)
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
return None
try:
if hasattr(self.model, "deleted_at"):
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column"
)
if obj is None:
logger.warning(
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
)
return None
obj.deleted_at = None
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise

View File

@@ -0,0 +1,235 @@
# app/repositories/oauth_account.py
"""Repository for OAuthAccount model async CRUD operations."""
import logging
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_account import OAuthAccount
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthAccountCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthAccountRepository(BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]):
"""Repository for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and provider user ID."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and email."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider} email {email}: {e!s}"
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""Get all OAuth accounts linked to a user."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""Get a specific OAuth account for a user and provider."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""Create a new OAuth account link."""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token=obj_in.access_token,
refresh_token=obj_in.refresh_token,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
)
raise DuplicateEntryError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error(f"Integrity error creating OAuth account: {error_msg}")
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""Delete an OAuth account link."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
f"OAuth account deleted: {provider} unlinked from user {user_id}"
)
else:
logger.warning(
f"OAuth account not found for deletion: {provider} for user {user_id}"
)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token: str | None = None,
refresh_token: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""Update OAuth tokens for an account."""
try:
if access_token is not None:
account.access_token = access_token
if refresh_token is not None:
account.refresh_token = refresh_token
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error updating OAuth tokens: {e!s}")
raise
# Singleton instance
oauth_account_repo = OAuthAccountRepository(OAuthAccount)

View File

@@ -0,0 +1,108 @@
# app/repositories/oauth_authorization_code.py
"""Repository for OAuthAuthorizationCode model."""
import logging
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_authorization_code import OAuthAuthorizationCode
logger = logging.getLogger(__name__)
class OAuthAuthorizationCodeRepository:
"""Repository for OAuth 2.0 authorization codes."""
async def create_code(
self,
db: AsyncSession,
*,
code: str,
client_id: str,
user_id: UUID,
redirect_uri: str,
scope: str,
expires_at: datetime,
code_challenge: str | None = None,
code_challenge_method: str | None = None,
state: str | None = None,
nonce: str | None = None,
) -> OAuthAuthorizationCode:
"""Create and persist a new authorization code."""
auth_code = OAuthAuthorizationCode(
code=code,
client_id=client_id,
user_id=user_id,
redirect_uri=redirect_uri,
scope=scope,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
expires_at=expires_at,
used=False,
)
db.add(auth_code)
await db.commit()
return auth_code
async def consume_code_atomically(
self, db: AsyncSession, *, code: str
) -> UUID | None:
"""
Atomically mark a code as used and return its UUID.
Returns the UUID if the code was found and not yet used, None otherwise.
This prevents race conditions per RFC 6749 Section 4.1.2.
"""
stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
)
result = await db.execute(stmt)
row_id = result.scalar_one_or_none()
if row_id is not None:
await db.commit()
return row_id
async def get_by_id(
self, db: AsyncSession, *, code_id: UUID
) -> OAuthAuthorizationCode | None:
"""Get authorization code by its UUID primary key."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
)
return result.scalar_one_or_none()
async def get_by_code(
self, db: AsyncSession, *, code: str
) -> OAuthAuthorizationCode | None:
"""Get authorization code by the code string value."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
)
return result.scalar_one_or_none()
async def cleanup_expired(self, db: AsyncSession) -> int:
"""Delete all expired authorization codes. Returns count deleted."""
result = await db.execute(
delete(OAuthAuthorizationCode).where(
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
# Singleton instance
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()

View File

@@ -0,0 +1,199 @@
# app/repositories/oauth_client.py
"""Repository for OAuthClient model async CRUD operations."""
import logging
import secrets
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_client import OAuthClient
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthClientCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthClientRepository(BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]):
"""Repository for OAuth clients (provider mode)."""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Get OAuth client by client_id."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""Create a new OAuth client."""
try:
client_id = secrets.token_urlsafe(32)
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
from app.core.auth import get_password_hash
client_secret_hash = get_password_hash(client_secret)
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
)
return db_obj, client_secret
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Error creating OAuth client: {error_msg}")
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Deactivate an OAuth client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info(f"OAuth client deactivated: {client.client_name}")
return client
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""Validate that a redirect URI is allowed for a client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e: # pragma: no cover
logger.error(f"Error validating redirect URI: {e!s}")
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""Verify client credentials."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
from app.core.auth import verify_password
stored_hash: str = str(client.client_secret_hash)
if stored_hash.startswith("$2"):
return verify_password(client_secret, stored_hash)
else:
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e: # pragma: no cover
logger.error(f"Error verifying client secret: {e!s}")
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""Get all OAuth clients."""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting all OAuth clients: {e!s}")
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""Delete an OAuth client permanently."""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(f"OAuth client deleted: {client_id}")
else:
logger.warning(f"OAuth client not found for deletion: {client_id}")
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
raise
# Singleton instance
oauth_client_repo = OAuthClientRepository(OAuthClient)

View File

@@ -0,0 +1,112 @@
# app/repositories/oauth_consent.py
"""Repository for OAuthConsent model."""
import logging
from uuid import UUID
from typing import Any
from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent
logger = logging.getLogger(__name__)
class OAuthConsentRepository:
"""Repository for OAuth consent records (user grants to clients)."""
async def get_consent(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> OAuthConsent | None:
"""Get the consent record for a user-client pair, or None if not found."""
result = await db.execute(
select(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
return result.scalar_one_or_none()
async def grant_consent(
self,
db: AsyncSession,
*,
user_id: UUID,
client_id: str,
scopes: list[str],
) -> OAuthConsent:
"""
Create or update consent for a user-client pair.
If consent already exists, the new scopes are merged with existing ones.
Returns the created or updated consent record.
"""
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
if consent:
existing = set(consent.granted_scopes.split()) if consent.granted_scopes else set()
merged = existing | set(scopes)
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
else:
consent = OAuthConsent(
user_id=user_id,
client_id=client_id,
granted_scopes=" ".join(sorted(set(scopes))),
)
db.add(consent)
await db.commit()
await db.refresh(consent)
return consent
async def get_user_consents_with_clients(
self, db: AsyncSession, *, user_id: UUID
) -> list[dict[str, Any]]:
"""Get all consent records for a user joined with client details."""
result = await db.execute(
select(OAuthConsent, OAuthClient)
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
.where(OAuthConsent.user_id == user_id)
)
rows = result.all()
return [
{
"client_id": consent.client_id,
"client_name": client.client_name,
"client_description": client.client_description,
"granted_scopes": consent.granted_scopes.split()
if consent.granted_scopes
else [],
"granted_at": consent.created_at.isoformat(),
}
for consent, client in rows
]
async def revoke_consent(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> bool:
"""
Delete the consent record for a user-client pair.
Returns True if a record was found and deleted.
Note: Callers are responsible for also revoking associated tokens.
"""
result = await db.execute(
delete(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
await db.commit()
return result.rowcount > 0 # type: ignore[attr-defined]
# Singleton instance
oauth_consent_repo = OAuthConsentRepository()

View File

@@ -0,0 +1,146 @@
# app/repositories/oauth_provider_token.py
"""Repository for OAuthProviderRefreshToken model."""
import logging
from datetime import UTC, datetime, timedelta
from uuid import UUID
from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_provider_token import OAuthProviderRefreshToken
logger = logging.getLogger(__name__)
class OAuthProviderTokenRepository:
"""Repository for OAuth provider refresh tokens."""
async def create_token(
self,
db: AsyncSession,
*,
token_hash: str,
jti: str,
client_id: str,
user_id: UUID,
scope: str,
expires_at: datetime,
device_info: str | None = None,
ip_address: str | None = None,
) -> OAuthProviderRefreshToken:
"""Create and persist a new refresh token record."""
token = OAuthProviderRefreshToken(
token_hash=token_hash,
jti=jti,
client_id=client_id,
user_id=user_id,
scope=scope,
expires_at=expires_at,
device_info=device_info,
ip_address=ip_address,
)
db.add(token)
await db.commit()
return token
async def get_by_token_hash(
self, db: AsyncSession, *, token_hash: str
) -> OAuthProviderRefreshToken | None:
"""Get refresh token record by SHA-256 token hash."""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
)
return result.scalar_one_or_none()
async def get_by_jti(
self, db: AsyncSession, *, jti: str
) -> OAuthProviderRefreshToken | None:
"""Get refresh token record by JWT ID (JTI)."""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
return result.scalar_one_or_none()
async def revoke(
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
) -> None:
"""Mark a specific token record as revoked."""
token.revoked = True # type: ignore[assignment]
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
await db.commit()
async def revoke_all_for_user_client(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> int:
"""
Revoke all active tokens for a specific user-client pair.
Used when security incidents are detected (e.g., authorization code reuse).
Returns the number of tokens revoked.
"""
result = await db.execute(
update(OAuthProviderRefreshToken)
.where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.client_id == client_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
.values(revoked=True)
)
count = result.rowcount # type: ignore[attr-defined]
if count > 0:
await db.commit()
return count
async def revoke_all_for_user(
self, db: AsyncSession, *, user_id: UUID
) -> int:
"""
Revoke all active tokens for a user across all clients.
Used when user changes password or logs out everywhere.
Returns the number of tokens revoked.
"""
result = await db.execute(
update(OAuthProviderRefreshToken)
.where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
.values(revoked=True)
)
count = result.rowcount # type: ignore[attr-defined]
if count > 0:
await db.commit()
return count
async def cleanup_expired(
self, db: AsyncSession, *, cutoff_days: int = 7
) -> int:
"""
Delete expired refresh tokens older than cutoff_days.
Should be called periodically (e.g., daily).
Returns the number of tokens deleted.
"""
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
result = await db.execute(
delete(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.expires_at < cutoff
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
# Singleton instance
oauth_provider_token_repo = OAuthProviderTokenRepository()

View File

@@ -0,0 +1,113 @@
# app/repositories/oauth_state.py
"""Repository for OAuthState model async CRUD operations."""
import logging
from datetime import UTC, datetime
from pydantic import BaseModel
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_state import OAuthState
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthStateCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
"""Repository for OAuth state (CSRF protection)."""
async def create_state(
self, db: AsyncSession, *, obj_in: OAuthStateCreate
) -> OAuthState:
"""Create a new OAuth state for CSRF protection."""
try:
db_obj = OAuthState(
state=obj_in.state,
code_verifier=obj_in.code_verifier,
nonce=obj_in.nonce,
provider=obj_in.provider,
redirect_uri=obj_in.redirect_uri,
user_id=obj_in.user_id,
expires_at=obj_in.expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.debug(f"OAuth state created for {obj_in.provider}")
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"OAuth state collision: {error_msg}")
raise DuplicateEntryError("Failed to create OAuth state, please retry")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
raise
async def get_and_consume_state(
self, db: AsyncSession, *, state: str
) -> OAuthState | None:
"""Get and delete OAuth state (consume it)."""
try:
result = await db.execute(
select(OAuthState).where(OAuthState.state == state)
)
db_obj = result.scalar_one_or_none()
if db_obj is None:
logger.warning(f"OAuth state not found: {state[:8]}...")
return None
now = datetime.now(UTC)
expires_at = db_obj.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.warning(f"OAuth state expired: {state[:8]}...")
await db.delete(db_obj)
await db.commit()
return None
await db.delete(db_obj)
await db.commit()
logger.debug(f"OAuth state consumed: {state[:8]}...")
return db_obj
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error consuming OAuth state: {e!s}")
raise
async def cleanup_expired(self, db: AsyncSession) -> int:
"""Clean up expired OAuth states."""
try:
now = datetime.now(UTC)
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired OAuth states")
return count
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
raise
# Singleton instance
oauth_state_repo = OAuthStateRepository(OAuthState)

View File

@@ -0,0 +1,496 @@
# app/repositories/organization.py
"""Repository for Organization model async CRUD operations using SQLAlchemy 2.0 patterns."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy import and_, case, func, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.base import BaseRepository
from app.schemas.organizations import (
OrganizationCreate,
OrganizationUpdate,
)
logger = logging.getLogger(__name__)
class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]):
"""Repository for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
"""Get organization by slug."""
try:
result = await db.execute(
select(Organization).where(Organization.slug == slug)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {e!s}")
raise
async def create(
self, db: AsyncSession, *, obj_in: OrganizationCreate
) -> Organization:
"""Create a new organization with error handling."""
try:
db_obj = Organization(
name=obj_in.name,
slug=obj_in.slug,
description=obj_in.description,
is_active=obj_in.is_active,
settings=obj_in.settings or {},
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower() or "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise DuplicateEntryError(
f"Organization with slug '{obj_in.slug}' already exists"
)
logger.error(f"Integrity error creating organization: {error_msg}")
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(
f"Unexpected error creating organization: {e!s}", exc_info=True
)
raise
async def get_multi_with_filters(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
) -> tuple[list[Organization], int]:
"""Get multiple organizations with filtering, searching, and sorting."""
try:
query = select(Organization)
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
sort_column = getattr(Organization, sort_by, Organization.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
query = query.offset(skip).limit(limit)
result = await db.execute(query)
organizations = list(result.scalars().all())
return organizations, total
except Exception as e:
logger.error(f"Error getting organizations with filters: {e!s}")
raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
"""Get the count of active members in an organization."""
try:
result = await db.execute(
select(func.count(UserOrganization.user_id)).where(
and_(
UserOrganization.organization_id == organization_id,
UserOrganization.is_active,
)
)
)
return result.scalar_one() or 0
except Exception as e:
logger.error(
f"Error getting member count for organization {organization_id}: {e!s}"
)
raise
async def get_multi_with_member_counts(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
try:
query = (
select(
Organization,
func.count(
func.distinct(
case(
(
UserOrganization.is_active,
UserOrganization.user_id,
),
else_=None,
)
)
).label("member_count"),
)
.outerjoin(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.group_by(Organization.id)
)
if is_active is not None:
query = query.where(Organization.is_active == is_active)
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
count_query = select(func.count(Organization.id))
if is_active is not None:
count_query = count_query.where(Organization.is_active == is_active)
if search:
count_query = count_query.where(search_filter)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
query = (
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
)
result = await db.execute(query)
rows = result.all()
orgs_with_counts = [
{"organization": org, "member_count": member_count}
for org, member_count in rows
]
return orgs_with_counts, total
except Exception as e:
logger.error(
f"Error getting organizations with member counts: {e!s}", exc_info=True
)
raise
async def add_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER,
custom_permissions: str | None = None,
) -> UserOrganization:
"""Add a user to an organization with a specific role."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
)
)
)
existing = result.scalar_one_or_none()
if existing:
if not existing.is_active:
existing.is_active = True
existing.role = role
existing.custom_permissions = custom_permissions
await db.commit()
await db.refresh(existing)
return existing
else:
raise DuplicateEntryError("User is already a member of this organization")
user_org = UserOrganization(
user_id=user_id,
organization_id=organization_id,
role=role,
is_active=True,
custom_permissions=custom_permissions,
)
db.add(user_org)
await db.commit()
await db.refresh(user_org)
return user_org
except IntegrityError as e:
await db.rollback()
logger.error(f"Integrity error adding user to organization: {e!s}")
raise IntegrityConstraintError("Failed to add user to organization")
except Exception as e:
await db.rollback()
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
raise
async def remove_user(
self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
) -> bool:
"""Remove a user from an organization (soft delete)."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
)
)
)
user_org = result.scalar_one_or_none()
if not user_org:
return False
user_org.is_active = False
await db.commit()
return True
except Exception as e:
await db.rollback()
logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
raise
async def update_user_role(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole,
custom_permissions: str | None = None,
) -> UserOrganization | None:
"""Update a user's role in an organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
)
)
)
user_org = result.scalar_one_or_none()
if not user_org:
return None
user_org.role = role
if custom_permissions is not None:
user_org.custom_permissions = custom_permissions
await db.commit()
await db.refresh(user_org)
return user_org
except Exception as e:
await db.rollback()
logger.error(f"Error updating user role: {e!s}", exc_info=True)
raise
async def get_organization_members(
self,
db: AsyncSession,
*,
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool = True,
) -> tuple[list[dict[str, Any]], int]:
"""Get members of an organization with user details."""
try:
query = (
select(UserOrganization, User)
.join(User, UserOrganization.user_id == User.id)
.where(UserOrganization.organization_id == organization_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
.where(
UserOrganization.is_active == is_active
if is_active is not None
else True
)
.alias()
)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
query = (
query.order_by(UserOrganization.created_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query)
results = result.all()
members = []
for user_org, user in results:
members.append(
{
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at,
}
)
return members, total
except Exception as e:
logger.error(f"Error getting organization members: {e!s}")
raise
async def get_user_organizations(
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
) -> list[Organization]:
"""Get all organizations a user belongs to."""
try:
query = (
select(Organization)
.join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user organizations: {e!s}")
raise
async def get_user_organizations_with_details(
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
) -> list[dict[str, Any]]:
"""Get user's organizations with role and member count in SINGLE QUERY."""
try:
member_count_subq = (
select(
UserOrganization.organization_id,
func.count(UserOrganization.user_id).label("member_count"),
)
.where(UserOrganization.is_active)
.group_by(UserOrganization.organization_id)
.subquery()
)
query = (
select(
Organization,
UserOrganization.role,
func.coalesce(member_count_subq.c.member_count, 0).label(
"member_count"
),
)
.join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.outerjoin(
member_count_subq,
Organization.id == member_count_subq.c.organization_id,
)
.where(UserOrganization.user_id == user_id)
)
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
result = await db.execute(query)
rows = result.all()
return [
{"organization": org, "role": role, "member_count": member_count}
for org, role, member_count in rows
]
except Exception as e:
logger.error(
f"Error getting user organizations with details: {e!s}", exc_info=True
)
raise
async def get_user_role_in_org(
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> OrganizationRole | None:
"""Get a user's role in a specific organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
UserOrganization.is_active,
)
)
)
user_org = result.scalar_one_or_none()
return user_org.role if user_org else None
except Exception as e:
logger.error(f"Error getting user role in org: {e!s}")
raise
async def is_user_org_owner(
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> bool:
"""Check if a user is an owner of an organization."""
role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role == OrganizationRole.OWNER
async def is_user_org_admin(
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> bool:
"""Check if a user is an owner or admin of an organization."""
role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
# Singleton instance
organization_repo = OrganizationRepository(Organization)

View File

@@ -0,0 +1,327 @@
# app/repositories/session.py
"""Repository for UserSession model async CRUD operations using SQLAlchemy 2.0 patterns."""
import logging
import uuid
from datetime import UTC, datetime, timedelta
from uuid import UUID
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.repository_exceptions import InvalidInputError, IntegrityConstraintError
from app.models.user_session import UserSession
from app.repositories.base import BaseRepository
from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__)
class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
"""Repository for UserSession model."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""Get session by refresh token JTI."""
try:
result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {e!s}")
raise
async def get_active_by_jti(
self, db: AsyncSession, *, jti: str
) -> UserSession | None:
"""Get active session by refresh token JTI."""
try:
result = await db.execute(
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active,
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {e!s}")
raise
async def get_user_sessions(
self,
db: AsyncSession,
*,
user_id: str,
active_only: bool = True,
with_user: bool = False,
) -> list[UserSession]:
"""Get all sessions for a user with optional eager loading."""
try:
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = select(UserSession).where(UserSession.user_id == user_uuid)
if with_user:
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active)
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {e!s}")
raise
async def create_session(
self, db: AsyncSession, *, obj_in: SessionCreate
) -> UserSession:
"""Create a new user session."""
try:
db_obj = UserSession(
user_id=obj_in.user_id,
refresh_token_jti=obj_in.refresh_token_jti,
device_name=obj_in.device_name,
device_id=obj_in.device_id,
ip_address=obj_in.ip_address,
user_agent=obj_in.user_agent,
last_used_at=obj_in.last_used_at,
expires_at=obj_in.expires_at,
is_active=True,
location_city=obj_in.location_city,
location_country=obj_in.location_country,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
f"(IP: {obj_in.ip_address})"
)
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error creating session: {e!s}", exc_info=True)
raise IntegrityConstraintError(f"Failed to create session: {e!s}")
async def deactivate(
self, db: AsyncSession, *, session_id: str
) -> UserSession | None:
"""Deactivate a session (logout from device)."""
try:
session = await self.get(db, id=session_id)
if not session:
logger.warning(f"Session {session_id} not found for deactivation")
return None
session.is_active = False
db.add(session)
await db.commit()
await db.refresh(session)
logger.info(
f"Session {session_id} deactivated for user {session.user_id} "
f"({session.device_name})"
)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating session {session_id}: {e!s}")
raise
async def deactivate_all_user_sessions(
self, db: AsyncSession, *, user_id: str
) -> int:
"""Deactivate all active sessions for a user (logout from all devices)."""
try:
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
stmt = (
update(UserSession)
.where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
.values(is_active=False)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
raise
async def update_last_used(
self, db: AsyncSession, *, session: UserSession
) -> UserSession:
"""Update the last_used_at timestamp for a session."""
try:
session.last_used_at = datetime.now(UTC)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {e!s}")
raise
async def update_refresh_token(
self,
db: AsyncSession,
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime,
) -> UserSession:
"""Update session with new refresh token JTI and expiration."""
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(UTC)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(
f"Error updating refresh token for session {session.id}: {e!s}"
)
raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""Clean up expired sessions using optimized bulk DELETE."""
try:
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
now = datetime.now(UTC)
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False, # noqa: E712
UserSession.expires_at < now,
UserSession.created_at < cutoff_date,
)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired sessions: {e!s}")
raise
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
"""Clean up expired and inactive sessions for a specific user."""
try:
try:
uuid_obj = uuid.UUID(user_id)
except (ValueError, AttributeError):
logger.error(f"Invalid UUID format: {user_id}")
raise InvalidInputError(f"Invalid user ID format: {user_id}")
now = datetime.now(UTC)
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
UserSession.is_active == False, # noqa: E712
UserSession.expires_at < now,
)
)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
)
return count
except Exception as e:
await db.rollback()
logger.error(
f"Error cleaning up expired sessions for user {user_id}: {e!s}"
)
raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
"""Get count of active sessions for a user."""
try:
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute(
select(func.count(UserSession.id)).where(
and_(UserSession.user_id == user_uuid, UserSession.is_active)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {e!s}")
raise
async def get_all_sessions(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
with_user: bool = True,
) -> tuple[list[UserSession], int]:
"""Get all sessions across all users with pagination (admin only)."""
try:
query = select(UserSession)
if with_user:
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active)
count_query = select(func.count(UserSession.id))
if active_only:
count_query = count_query.where(UserSession.is_active)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
query = (
query.order_by(UserSession.last_used_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query)
sessions = list(result.scalars().all())
return sessions, total
except Exception as e:
logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
raise
# Singleton instance
session_repo = SessionRepository(UserSession)

View File

@@ -0,0 +1,265 @@
# app/repositories/user.py
"""Repository for User model async CRUD operations using SQLAlchemy 2.0 patterns."""
import logging
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import or_, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_password_hash_async
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
from app.models.user import User
from app.repositories.base import BaseRepository
from app.schemas.users import UserCreate, UserUpdate
logger = logging.getLogger(__name__)
class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
"""Repository for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
"""Get user by email address."""
try:
result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by email {email}: {e!s}")
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with async password hashing and error handling."""
try:
password_hash = await get_password_hash_async(obj_in.password)
db_obj = User(
email=obj_in.email,
password_hash=password_hash,
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number
if hasattr(obj_in, "phone_number")
else None,
is_superuser=obj_in.is_superuser
if hasattr(obj_in, "is_superuser")
else False,
preferences={},
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise DuplicateEntryError(f"User with email {obj_in.email} already exists")
logger.error(f"Integrity error creating user: {error_msg}")
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
raise
async def create_oauth_user(
self,
db: AsyncSession,
*,
email: str,
first_name: str = "User",
last_name: str | None = None,
) -> User:
"""Create a new passwordless user for OAuth sign-in."""
try:
db_obj = User(
email=email,
password_hash=None, # OAuth-only user
first_name=first_name,
last_name=last_name,
is_active=True,
is_superuser=False,
)
db.add(db_obj)
await db.flush() # Get user.id without committing
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {email}")
raise DuplicateEntryError(f"User with email {email} already exists")
logger.error(f"Integrity error creating OAuth user: {error_msg}")
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating OAuth user: {e!s}", exc_info=True)
raise
async def update(
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
) -> User:
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async(
update_data["password"]
)
del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data)
async def update_password(
self, db: AsyncSession, *, user: User, password_hash: str
) -> User:
"""Set a new password hash on a user and commit."""
user.password_hash = password_hash
await db.commit()
await db.refresh(user)
return user
async def get_multi_with_total(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: str | None = None,
sort_order: str = "asc",
filters: dict[str, Any] | None = None,
search: str | None = None,
) -> tuple[list[User], int]:
"""Get multiple users with total count, filtering, sorting, and search."""
if skip < 0:
raise InvalidInputError("skip must be non-negative")
if limit < 0:
raise InvalidInputError("limit must be non-negative")
if limit > 1000:
raise InvalidInputError("Maximum limit is 1000")
try:
query = select(User)
query = query.where(User.deleted_at.is_(None))
if filters:
for field, value in filters.items():
if hasattr(User, field) and value is not None:
query = query.where(getattr(User, field) == value)
if search:
search_filter = or_(
User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%"),
)
query = query.where(search_filter)
from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
if sort_by and hasattr(User, sort_by):
sort_column = getattr(User, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
query = query.offset(skip).limit(limit)
result = await db.execute(query)
users = list(result.scalars().all())
return users, total
except Exception as e:
logger.error(f"Error retrieving paginated users: {e!s}")
raise
async def bulk_update_status(
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
) -> int:
"""Bulk update is_active status for multiple users."""
try:
if not user_ids:
return 0
stmt = (
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None))
.values(is_active=is_active, updated_at=datetime.now(UTC))
)
result = await db.execute(stmt)
await db.commit()
updated_count = result.rowcount
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
return updated_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: list[UUID],
exclude_user_id: UUID | None = None,
) -> int:
"""Bulk soft delete multiple users."""
try:
if not user_ids:
return 0
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids:
return 0
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None))
.values(
deleted_at=datetime.now(UTC),
is_active=False,
updated_at=datetime.now(UTC),
)
)
result = await db.execute(stmt)
await db.commit()
deleted_count = result.rowcount
logger.info(f"Bulk soft deleted {deleted_count} users")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
raise
def is_active(self, user: User) -> bool:
"""Check if user is active."""
return user.is_active
def is_superuser(self, user: User) -> bool:
"""Check if user is a superuser."""
return user.is_superuser
# Singleton instance
user_repo = UserRepository(User)