From 98b455fdc369df2b7ba41f9cdbafd7679fef0803 Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Fri, 27 Feb 2026 09:32:57 +0100 Subject: [PATCH] =?UTF-8?q?refactor(backend):=20enforce=20route=E2=86=92se?= =?UTF-8?q?rvice=E2=86=92repo=20layered=20architecture?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - introduce custom repository exception hierarchy (DuplicateEntryError, IntegrityConstraintError, InvalidInputError) replacing raw ValueError - eliminate all direct repository imports and raw SQL from route layer - add UserService, SessionService, OrganizationService to service layer - add get_stats/get_org_distribution service methods replacing admin inline SQL - fix timing side-channel in authenticate_user via dummy bcrypt check - replace SHA-256 client secret fallback with explicit InvalidClientError - replace assert with InvalidGrantError in authorization code exchange - replace N+1 token revocation loops with bulk UPDATE statements - rename oauth account token fields (drop misleading 'encrypted' suffix) - add Alembic migration 0003 for token field column rename - add 45 new service/repository tests; 975 passing, 94% coverage --- ...ount_token_fields_drop_encrypted_suffix.py | 28 + backend/app/api/dependencies/auth.py | 10 +- backend/app/api/dependencies/permissions.py | 6 +- backend/app/api/dependencies/services.py | 41 + backend/app/api/routes/admin.py | 214 ++---- backend/app/api/routes/auth.py | 58 +- backend/app/api/routes/oauth.py | 9 +- backend/app/api/routes/oauth_provider.py | 32 +- backend/app/api/routes/organizations.py | 35 +- backend/app/api/routes/sessions.py | 10 +- backend/app/api/routes/users.py | 38 +- backend/app/core/repository_exceptions.py | 26 + backend/app/crud/__init__.py | 14 - backend/app/crud/oauth.py | 718 ------------------ backend/app/init_db.py | 2 +- backend/app/models/oauth_account.py | 6 +- backend/app/repositories/__init__.py | 39 + backend/app/{crud => repositories}/base.py | 101 +-- backend/app/repositories/oauth_account.py | 235 ++++++ .../repositories/oauth_authorization_code.py | 108 +++ backend/app/repositories/oauth_client.py | 199 +++++ backend/app/repositories/oauth_consent.py | 112 +++ .../app/repositories/oauth_provider_token.py | 146 ++++ backend/app/repositories/oauth_state.py | 113 +++ .../{crud => repositories}/organization.py | 73 +- backend/app/{crud => repositories}/session.py | 183 +---- backend/app/{crud => repositories}/user.py | 129 ++-- backend/app/schemas/oauth.py | 4 +- backend/app/services/__init__.py | 16 +- backend/app/services/auth_service.py | 87 ++- .../app/services/oauth_provider_service.py | 259 ++----- backend/app/services/oauth_service.py | 40 +- backend/app/services/organization_service.py | 157 ++++ backend/app/services/session_cleanup.py | 2 +- backend/app/services/session_service.py | 97 +++ backend/app/services/user_service.py | 120 +++ backend/tests/api/test_admin.py | 4 +- .../tests/api/test_admin_error_handlers.py | 34 +- backend/tests/api/test_auth_error_handlers.py | 14 +- backend/tests/api/test_auth_password_reset.py | 2 +- backend/tests/api/test_auth_security.py | 2 +- backend/tests/api/test_oauth.py | 6 +- backend/tests/api/test_organizations.py | 10 +- .../tests/api/test_permissions_security.py | 2 +- backend/tests/api/test_sessions.py | 26 +- backend/tests/api/test_users.py | 12 +- backend/tests/e2e/test_admin_workflows.py | 2 +- .../tests/e2e/test_organization_workflows.py | 2 +- .../tests/{crud => repositories}/__init__.py | 0 .../tests/{crud => repositories}/test_base.py | 51 +- .../test_base_db_failures.py | 7 +- .../{crud => repositories}/test_oauth.py | 15 +- .../test_organization.py | 11 +- .../{crud => repositories}/test_session.py | 5 +- .../test_session_db_failures.py | 7 +- .../tests/{crud => repositories}/test_user.py | 11 +- backend/tests/services/test_auth_service.py | 5 +- .../services/test_oauth_provider_service.py | 10 +- backend/tests/services/test_oauth_service.py | 3 +- .../services/test_organization_service.py | 447 +++++++++++ .../tests/services/test_session_service.py | 292 +++++++ backend/tests/services/test_user_service.py | 214 ++++++ 62 files changed, 2933 insertions(+), 1728 deletions(-) create mode 100644 backend/app/alembic/versions/0003_rename_oauth_account_token_fields_drop_encrypted_suffix.py create mode 100644 backend/app/api/dependencies/services.py create mode 100644 backend/app/core/repository_exceptions.py delete mode 100644 backend/app/crud/__init__.py delete mode 100755 backend/app/crud/oauth.py create mode 100644 backend/app/repositories/__init__.py rename backend/app/{crud => repositories}/base.py (80%) mode change 100755 => 100644 create mode 100644 backend/app/repositories/oauth_account.py create mode 100644 backend/app/repositories/oauth_authorization_code.py create mode 100644 backend/app/repositories/oauth_client.py create mode 100644 backend/app/repositories/oauth_consent.py create mode 100644 backend/app/repositories/oauth_provider_token.py create mode 100644 backend/app/repositories/oauth_state.py rename backend/app/{crud => repositories}/organization.py (87%) mode change 100755 => 100644 rename backend/app/{crud => repositories}/session.py (67%) mode change 100755 => 100644 rename backend/app/{crud => repositories}/user.py (71%) mode change 100755 => 100644 create mode 100644 backend/app/services/organization_service.py create mode 100644 backend/app/services/session_service.py create mode 100644 backend/app/services/user_service.py rename backend/tests/{crud => repositories}/__init__.py (100%) rename backend/tests/{crud => repositories}/test_base.py (94%) rename backend/tests/{crud => repositories}/test_base_db_failures.py (97%) rename backend/tests/{crud => repositories}/test_oauth.py (97%) rename backend/tests/{crud => repositories}/test_organization.py (98%) rename backend/tests/{crud => repositories}/test_session.py (99%) rename backend/tests/{crud => repositories}/test_session_db_failures.py (97%) rename backend/tests/{crud => repositories}/test_user.py (98%) create mode 100644 backend/tests/services/test_organization_service.py create mode 100644 backend/tests/services/test_session_service.py create mode 100644 backend/tests/services/test_user_service.py diff --git a/backend/app/alembic/versions/0003_rename_oauth_account_token_fields_drop_encrypted_suffix.py b/backend/app/alembic/versions/0003_rename_oauth_account_token_fields_drop_encrypted_suffix.py new file mode 100644 index 0000000..b717976 --- /dev/null +++ b/backend/app/alembic/versions/0003_rename_oauth_account_token_fields_drop_encrypted_suffix.py @@ -0,0 +1,28 @@ +"""rename oauth account token fields drop encrypted suffix + +Revision ID: 0003 +Revises: 0002 +Create Date: 2026-02-27 01:03:18.869178 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0003" +down_revision: str | None = "0002" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.alter_column("oauth_accounts", "access_token_encrypted", new_column_name="access_token") + op.alter_column("oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token") + + +def downgrade() -> None: + op.alter_column("oauth_accounts", "access_token", new_column_name="access_token_encrypted") + op.alter_column("oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted") diff --git a/backend/app/api/dependencies/auth.py b/backend/app/api/dependencies/auth.py index 5d6a7aa..63b3de0 100755 --- a/backend/app/api/dependencies/auth.py +++ b/backend/app/api/dependencies/auth.py @@ -1,12 +1,12 @@ from fastapi import Depends, Header, HTTPException, status from fastapi.security import OAuth2PasswordBearer from fastapi.security.utils import get_authorization_scheme_param -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data from app.core.database import get_db from app.models.user import User +from app.repositories.user import user_repo # OAuth2 configuration oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") @@ -32,9 +32,8 @@ async def get_current_user( # Decode token and get user ID token_data = get_token_data(token) - # Get user from database - result = await db.execute(select(User).where(User.id == token_data.user_id)) - user = result.scalar_one_or_none() + # Get user from database via repository + user = await user_repo.get(db, id=str(token_data.user_id)) if not user: raise HTTPException( @@ -144,8 +143,7 @@ async def get_optional_current_user( try: token_data = get_token_data(token) - result = await db.execute(select(User).where(User.id == token_data.user_id)) - user = result.scalar_one_or_none() + user = await user_repo.get(db, id=str(token_data.user_id)) if not user or not user.is_active: return None return user diff --git a/backend/app/api/dependencies/permissions.py b/backend/app/api/dependencies/permissions.py index 0d5aa40..5550326 100755 --- a/backend/app/api/dependencies/permissions.py +++ b/backend/app/api/dependencies/permissions.py @@ -15,9 +15,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.core.database import get_db -from app.crud.organization import organization as organization_crud from app.models.user import User from app.models.user_organization import OrganizationRole +from app.services.organization_service import organization_service def require_superuser(current_user: User = Depends(get_current_user)) -> User: @@ -81,7 +81,7 @@ class OrganizationPermission: return current_user # Get user's role in organization - user_role = await organization_crud.get_user_role_in_org( + user_role = await organization_service.get_user_role_in_org( db, user_id=current_user.id, organization_id=organization_id ) @@ -123,7 +123,7 @@ async def require_org_membership( if current_user.is_superuser: return current_user - user_role = await organization_crud.get_user_role_in_org( + user_role = await organization_service.get_user_role_in_org( db, user_id=current_user.id, organization_id=organization_id ) diff --git a/backend/app/api/dependencies/services.py b/backend/app/api/dependencies/services.py new file mode 100644 index 0000000..63d9ae2 --- /dev/null +++ b/backend/app/api/dependencies/services.py @@ -0,0 +1,41 @@ +# app/api/dependencies/services.py +"""FastAPI dependency functions for service singletons.""" + +from app.services import oauth_provider_service +from app.services.auth_service import AuthService +from app.services.oauth_service import OAuthService +from app.services.organization_service import OrganizationService, organization_service +from app.services.session_service import SessionService, session_service +from app.services.user_service import UserService, user_service + + +def get_auth_service() -> AuthService: + """Return the AuthService singleton for dependency injection.""" + from app.services.auth_service import AuthService as _AuthService + + return _AuthService() + + +def get_user_service() -> UserService: + """Return the UserService singleton for dependency injection.""" + return user_service + + +def get_organization_service() -> OrganizationService: + """Return the OrganizationService singleton for dependency injection.""" + return organization_service + + +def get_session_service() -> SessionService: + """Return the SessionService singleton for dependency injection.""" + return session_service + + +def get_oauth_service() -> OAuthService: + """Return OAuthService for dependency injection.""" + return OAuthService() + + +def get_oauth_provider_service(): + """Return the oauth_provider_service module for dependency injection.""" + return oauth_provider_service diff --git a/backend/app/api/routes/admin.py b/backend/app/api/routes/admin.py index e49a164..f2f2bb8 100755 --- a/backend/app/api/routes/admin.py +++ b/backend/app/api/routes/admin.py @@ -14,7 +14,6 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query, status from pydantic import BaseModel, Field -from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.permissions import require_superuser @@ -25,12 +24,12 @@ from app.core.exceptions import ( ErrorCode, NotFoundError, ) -from app.crud.organization import organization as organization_crud -from app.crud.session import session as session_crud -from app.crud.user import user as user_crud -from app.models.organization import Organization +from app.core.repository_exceptions import DuplicateEntryError from app.models.user import User -from app.models.user_organization import OrganizationRole, UserOrganization +from app.models.user_organization import OrganizationRole +from app.services.organization_service import organization_service +from app.services.session_service import session_service +from app.services.user_service import user_service from app.schemas.common import ( MessageResponse, PaginatedResponse, @@ -178,38 +177,29 @@ async def admin_get_stats( """Get admin dashboard statistics with real data from database.""" from app.core.config import settings - # Check if we have any data - total_users_query = select(func.count()).select_from(User) - total_users = (await db.execute(total_users_query)).scalar() or 0 + stats = await user_service.get_stats(db) + total_users = stats["total_users"] + active_count = stats["active_count"] + inactive_count = stats["inactive_count"] + all_users = stats["all_users"] # If database is essentially empty (only admin user), return demo data if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover logger.info("Returning demo stats data (empty database in demo mode)") return _generate_demo_stats() - # 1. User Growth (Last 30 days) - Improved calculation - datetime.now(UTC) - timedelta(days=30) - - # Get all users with their creation dates - all_users_query = select(User).order_by(User.created_at) - result = await db.execute(all_users_query) - all_users = result.scalars().all() - - # Build cumulative counts per day + # 1. User Growth (Last 30 days) user_growth = [] for i in range(29, -1, -1): date = datetime.now(UTC) - timedelta(days=i) date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC) date_end = date_start + timedelta(days=1) - # Count all users created before end of this day - # Make comparison timezone-aware total_users_on_date = sum( 1 for u in all_users if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end ) - # Count active users created before end of this day active_users_on_date = sum( 1 for u in all_users @@ -227,27 +217,16 @@ async def admin_get_stats( ) # 2. Organization Distribution - Top 6 organizations by member count - org_query = ( - select(Organization.name, func.count(UserOrganization.user_id).label("count")) - .join(UserOrganization, Organization.id == UserOrganization.organization_id) - .group_by(Organization.name) - .order_by(func.count(UserOrganization.user_id).desc()) - .limit(6) - ) - result = await db.execute(org_query) - org_dist = [ - OrgDistributionData(name=row.name, value=row.count) for row in result.all() - ] + org_rows = await organization_service.get_org_distribution(db, limit=6) + org_dist = [OrgDistributionData(name=r["name"], value=r["value"]) for r in org_rows] - # 3. User Registration Activity (Last 14 days) - NEW + # 3. User Registration Activity (Last 14 days) registration_activity = [] for i in range(13, -1, -1): date = datetime.now(UTC) - timedelta(days=i) date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC) date_end = date_start + timedelta(days=1) - # Count users created on this specific day - # Make comparison timezone-aware day_registrations = sum( 1 for u in all_users @@ -263,14 +242,6 @@ async def admin_get_stats( ) # 4. User Status - Active vs Inactive - active_query = select(func.count()).select_from(User).where(User.is_active) - inactive_query = ( - select(func.count()).select_from(User).where(User.is_active.is_(False)) - ) - - active_count = (await db.execute(active_query)).scalar() or 0 - inactive_count = (await db.execute(inactive_query)).scalar() or 0 - logger.info( f"User status counts - Active: {active_count}, Inactive: {inactive_count}" ) @@ -321,7 +292,7 @@ async def admin_list_users( filters["is_superuser"] = is_superuser # Get users with search - users, total = await user_crud.get_multi_with_total( + users, total = await user_service.list_users( db, skip=pagination.offset, limit=pagination.limit, @@ -364,12 +335,12 @@ async def admin_create_user( Allows setting is_superuser and other fields. """ try: - user = await user_crud.create(db, obj_in=user_in) + user = await user_service.create_user(db, user_in) logger.info(f"Admin {admin.email} created user {user.email}") return user - except ValueError as e: + except DuplicateEntryError as e: logger.warning(f"Failed to create user: {e!s}") - raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS) + raise DuplicateError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS) except Exception as e: logger.error(f"Error creating user (admin): {e!s}", exc_info=True) raise @@ -388,11 +359,7 @@ async def admin_get_user( db: AsyncSession = Depends(get_db), ) -> Any: """Get detailed information about a specific user.""" - user = await user_crud.get(db, id=user_id) - if not user: - raise NotFoundError( - message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND - ) + user = await user_service.get_user(db, str(user_id)) return user @@ -411,18 +378,11 @@ async def admin_update_user( ) -> Any: """Update user information with admin privileges.""" try: - user = await user_crud.get(db, id=user_id) - if not user: - raise NotFoundError( - message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND - ) - - updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in) + user = await user_service.get_user(db, str(user_id)) + updated_user = await user_service.update_user(db, user=user, obj_in=user_in) logger.info(f"Admin {admin.email} updated user {updated_user.email}") return updated_user - except NotFoundError: - raise except Exception as e: logger.error(f"Error updating user (admin): {e!s}", exc_info=True) raise @@ -442,11 +402,7 @@ async def admin_delete_user( ) -> Any: """Soft delete a user (sets deleted_at timestamp).""" try: - user = await user_crud.get(db, id=user_id) - if not user: - raise NotFoundError( - message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND - ) + user = await user_service.get_user(db, str(user_id)) # Prevent deleting yourself if user.id == admin.id: @@ -456,15 +412,13 @@ async def admin_delete_user( error_code=ErrorCode.OPERATION_FORBIDDEN, ) - await user_crud.soft_delete(db, id=user_id) + await user_service.soft_delete_user(db, str(user_id)) logger.info(f"Admin {admin.email} deleted user {user.email}") return MessageResponse( success=True, message=f"User {user.email} has been deleted" ) - except NotFoundError: - raise except Exception as e: logger.error(f"Error deleting user (admin): {e!s}", exc_info=True) raise @@ -484,21 +438,14 @@ async def admin_activate_user( ) -> Any: """Activate a user account.""" try: - user = await user_crud.get(db, id=user_id) - if not user: - raise NotFoundError( - message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND - ) - - await user_crud.update(db, db_obj=user, obj_in={"is_active": True}) + user = await user_service.get_user(db, str(user_id)) + await user_service.update_user(db, user=user, obj_in={"is_active": True}) logger.info(f"Admin {admin.email} activated user {user.email}") return MessageResponse( success=True, message=f"User {user.email} has been activated" ) - except NotFoundError: - raise except Exception as e: logger.error(f"Error activating user (admin): {e!s}", exc_info=True) raise @@ -518,11 +465,7 @@ async def admin_deactivate_user( ) -> Any: """Deactivate a user account.""" try: - user = await user_crud.get(db, id=user_id) - if not user: - raise NotFoundError( - message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND - ) + user = await user_service.get_user(db, str(user_id)) # Prevent deactivating yourself if user.id == admin.id: @@ -532,15 +475,13 @@ async def admin_deactivate_user( error_code=ErrorCode.OPERATION_FORBIDDEN, ) - await user_crud.update(db, db_obj=user, obj_in={"is_active": False}) + await user_service.update_user(db, user=user, obj_in={"is_active": False}) logger.info(f"Admin {admin.email} deactivated user {user.email}") return MessageResponse( success=True, message=f"User {user.email} has been deactivated" ) - except NotFoundError: - raise except Exception as e: logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True) raise @@ -567,16 +508,16 @@ async def admin_bulk_user_action( try: # Use efficient bulk operations instead of loop if bulk_action.action == BulkAction.ACTIVATE: - affected_count = await user_crud.bulk_update_status( + affected_count = await user_service.bulk_update_status( db, user_ids=bulk_action.user_ids, is_active=True ) elif bulk_action.action == BulkAction.DEACTIVATE: - affected_count = await user_crud.bulk_update_status( + affected_count = await user_service.bulk_update_status( db, user_ids=bulk_action.user_ids, is_active=False ) elif bulk_action.action == BulkAction.DELETE: # bulk_soft_delete automatically excludes the admin user - affected_count = await user_crud.bulk_soft_delete( + affected_count = await user_service.bulk_soft_delete( db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id ) else: # pragma: no cover @@ -624,7 +565,7 @@ async def admin_list_organizations( """List all organizations with filtering and search.""" try: # Use optimized method that gets member counts in single query (no N+1) - orgs_with_data, total = await organization_crud.get_multi_with_member_counts( + orgs_with_data, total = await organization_service.get_multi_with_member_counts( db, skip=pagination.offset, limit=pagination.limit, @@ -680,7 +621,7 @@ async def admin_create_organization( ) -> Any: """Create a new organization.""" try: - org = await organization_crud.create(db, obj_in=org_in) + org = await organization_service.create_organization(db, obj_in=org_in) logger.info(f"Admin {admin.email} created organization {org.name}") # Add member count @@ -697,9 +638,9 @@ async def admin_create_organization( } return OrganizationResponse(**org_dict) - except ValueError as e: + except DuplicateEntryError as e: logger.warning(f"Failed to create organization: {e!s}") - raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS) + raise DuplicateError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS) except Exception as e: logger.error(f"Error creating organization (admin): {e!s}", exc_info=True) raise @@ -718,12 +659,7 @@ async def admin_get_organization( db: AsyncSession = Depends(get_db), ) -> Any: """Get detailed information about a specific organization.""" - org = await organization_crud.get(db, id=org_id) - if not org: - raise NotFoundError( - message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND - ) - + org = await organization_service.get_organization(db, str(org_id)) org_dict = { "id": org.id, "name": org.name, @@ -733,7 +669,7 @@ async def admin_get_organization( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": await organization_crud.get_member_count( + "member_count": await organization_service.get_member_count( db, organization_id=org.id ), } @@ -755,14 +691,10 @@ async def admin_update_organization( ) -> Any: """Update organization information.""" try: - org = await organization_crud.get(db, id=org_id) - if not org: - raise NotFoundError( - message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND, - ) - - updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in) + org = await organization_service.get_organization(db, str(org_id)) + updated_org = await organization_service.update_organization( + db, org=org, obj_in=org_in + ) logger.info(f"Admin {admin.email} updated organization {updated_org.name}") org_dict = { @@ -774,14 +706,12 @@ async def admin_update_organization( "settings": updated_org.settings, "created_at": updated_org.created_at, "updated_at": updated_org.updated_at, - "member_count": await organization_crud.get_member_count( + "member_count": await organization_service.get_member_count( db, organization_id=updated_org.id ), } return OrganizationResponse(**org_dict) - except NotFoundError: - raise except Exception as e: logger.error(f"Error updating organization (admin): {e!s}", exc_info=True) raise @@ -801,22 +731,14 @@ async def admin_delete_organization( ) -> Any: """Delete an organization and all its relationships.""" try: - org = await organization_crud.get(db, id=org_id) - if not org: - raise NotFoundError( - message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND, - ) - - await organization_crud.remove(db, id=org_id) + org = await organization_service.get_organization(db, str(org_id)) + await organization_service.remove_organization(db, str(org_id)) logger.info(f"Admin {admin.email} deleted organization {org.name}") return MessageResponse( success=True, message=f"Organization {org.name} has been deleted" ) - except NotFoundError: - raise except Exception as e: logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True) raise @@ -838,14 +760,8 @@ async def admin_list_organization_members( ) -> Any: """List all members of an organization.""" try: - org = await organization_crud.get(db, id=org_id) - if not org: - raise NotFoundError( - message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND, - ) - - members, total = await organization_crud.get_organization_members( + await organization_service.get_organization(db, str(org_id)) # validates exists + members, total = await organization_service.get_organization_members( db, organization_id=org_id, skip=pagination.offset, @@ -898,21 +814,10 @@ async def admin_add_organization_member( ) -> Any: """Add a user to an organization.""" try: - org = await organization_crud.get(db, id=org_id) - if not org: - raise NotFoundError( - message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND, - ) + org = await organization_service.get_organization(db, str(org_id)) + user = await user_service.get_user(db, str(request.user_id)) - user = await user_crud.get(db, id=request.user_id) - if not user: - raise NotFoundError( - message=f"User {request.user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND, - ) - - await organization_crud.add_user( + await organization_service.add_member( db, organization_id=org_id, user_id=request.user_id, role=request.role ) @@ -925,14 +830,11 @@ async def admin_add_organization_member( success=True, message=f"User {user.email} added to organization {org.name}" ) - except ValueError as e: + except DuplicateEntryError as e: logger.warning(f"Failed to add user to organization: {e!s}") - # Use DuplicateError for "already exists" scenarios raise DuplicateError( message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id" ) - except NotFoundError: - raise except Exception as e: logger.error( f"Error adding member to organization (admin): {e!s}", exc_info=True @@ -955,20 +857,10 @@ async def admin_remove_organization_member( ) -> Any: """Remove a user from an organization.""" try: - org = await organization_crud.get(db, id=org_id) - if not org: - raise NotFoundError( - message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND, - ) + org = await organization_service.get_organization(db, str(org_id)) + user = await user_service.get_user(db, str(user_id)) - user = await user_crud.get(db, id=user_id) - if not user: - raise NotFoundError( - message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND - ) - - success = await organization_crud.remove_user( + success = await organization_service.remove_member( db, organization_id=org_id, user_id=user_id ) @@ -1022,7 +914,7 @@ async def admin_list_sessions( """List all sessions across all users with filtering and pagination.""" try: # Get sessions with user info (eager loaded to prevent N+1) - sessions, total = await session_crud.get_all_sessions( + sessions, total = await session_service.get_all_sessions( db, skip=pagination.offset, limit=pagination.limit, diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index 72ba3cb..0160153 100755 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -15,16 +15,14 @@ from app.core.auth import ( TokenExpiredError, TokenInvalidError, decode_token, - get_password_hash, ) from app.core.database import get_db from app.core.exceptions import ( AuthenticationError as AuthError, DatabaseError, + DuplicateError, ErrorCode, ) -from app.crud.session import session as session_crud -from app.crud.user import user as user_crud from app.models.user import User from app.schemas.common import MessageResponse from app.schemas.sessions import LogoutRequest, SessionCreate @@ -39,6 +37,8 @@ from app.schemas.users import ( ) from app.services.auth_service import AuthenticationError, AuthService from app.services.email_service import email_service +from app.services.session_service import session_service +from app.services.user_service import user_service from app.utils.device import extract_device_info from app.utils.security import create_password_reset_token, verify_password_reset_token @@ -91,7 +91,7 @@ async def _create_login_session( location_country=device_info.location_country, ) - await session_crud.create_session(db, obj_in=session_data) + await session_service.create_session(db, obj_in=session_data) logger.info( f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} " @@ -123,8 +123,14 @@ async def register_user( try: user = await AuthService.create_user(db, user_data) return user - except AuthenticationError as e: + except DuplicateError: # SECURITY: Don't reveal if email exists - generic error message + logger.warning(f"Registration failed: duplicate email {user_data.email}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Registration failed. Please check your information and try again.", + ) + except AuthError as e: logger.warning(f"Registration failed: {e!s}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -259,7 +265,7 @@ async def refresh_token( ) # Check if session exists and is active - session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti) + session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti) if not session: logger.warning( @@ -279,7 +285,7 @@ async def refresh_token( # Update session with new refresh token JTI and expiration try: - await session_crud.update_refresh_token( + await session_service.update_refresh_token( db, session=session, new_jti=new_refresh_payload.jti, @@ -347,7 +353,7 @@ async def request_password_reset( """ try: # Look up user by email - user = await user_crud.get_by_email(db, email=reset_request.email) + user = await user_service.get_by_email(db, email=reset_request.email) # Only send email if user exists and is active if user and user.is_active: @@ -412,31 +418,25 @@ async def confirm_password_reset( detail="Invalid or expired password reset token", ) - # Look up user - user = await user_crud.get_by_email(db, email=email) - - if not user: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="User not found" + # Reset password via service (validates user exists and is active) + try: + user = await AuthService.reset_password( + db, email=email, new_password=reset_confirm.new_password ) - - if not user.is_active: + except AuthenticationError as e: + err_msg = str(e) + if "inactive" in err_msg.lower(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg + ) raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="User account is inactive", + status_code=status.HTTP_404_NOT_FOUND, detail=err_msg ) - # Update password - user.password_hash = get_password_hash(reset_confirm.new_password) - db.add(user) - await db.commit() - # SECURITY: Invalidate all existing sessions after password reset # This prevents stolen sessions from being used after password change - from app.crud.session import session as session_crud - try: - deactivated_count = await session_crud.deactivate_all_user_sessions( + deactivated_count = await session_service.deactivate_all_user_sessions( db, user_id=str(user.id) ) logger.info( @@ -511,7 +511,7 @@ async def logout( return MessageResponse(success=True, message="Logged out successfully") # Find the session by JTI - session = await session_crud.get_by_jti(db, jti=refresh_payload.jti) + session = await session_service.get_by_jti(db, jti=refresh_payload.jti) if session: # Verify session belongs to current user (security check) @@ -526,7 +526,7 @@ async def logout( ) # Deactivate the session - await session_crud.deactivate(db, session_id=str(session.id)) + await session_service.deactivate(db, session_id=str(session.id)) logger.info( f"User {current_user.id} logged out from {session.device_name} " @@ -584,7 +584,7 @@ async def logout_all( """ try: # Deactivate all sessions for this user - count = await session_crud.deactivate_all_user_sessions( + count = await session_service.deactivate_all_user_sessions( db, user_id=str(current_user.id) ) diff --git a/backend/app/api/routes/oauth.py b/backend/app/api/routes/oauth.py index 39dbd38..c5ab491 100644 --- a/backend/app/api/routes/oauth.py +++ b/backend/app/api/routes/oauth.py @@ -25,8 +25,7 @@ from app.core.auth import decode_token from app.core.config import settings from app.core.database import get_db from app.core.exceptions import AuthenticationError as AuthError -from app.crud import oauth_account -from app.crud.session import session as session_crud +from app.services.session_service import session_service from app.models.user import User from app.schemas.oauth import ( OAuthAccountsListResponse, @@ -82,7 +81,7 @@ async def _create_oauth_login_session( location_country=device_info.location_country, ) - await session_crud.create_session(db, obj_in=session_data) + await session_service.create_session(db, obj_in=session_data) logger.info( f"OAuth login successful: {user.email} via {provider} " @@ -289,7 +288,7 @@ async def list_accounts( Returns: List of linked OAuth accounts """ - accounts = await oauth_account.get_user_accounts(db, user_id=current_user.id) + accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id) return OAuthAccountsListResponse(accounts=accounts) @@ -397,7 +396,7 @@ async def start_link( ) # Check if user already has this provider linked - existing = await oauth_account.get_user_account_by_provider( + existing = await OAuthService.get_user_account_by_provider( db, user_id=current_user.id, provider=provider ) if existing: diff --git a/backend/app/api/routes/oauth_provider.py b/backend/app/api/routes/oauth_provider.py index 4699cdc..024cc14 100644 --- a/backend/app/api/routes/oauth_provider.py +++ b/backend/app/api/routes/oauth_provider.py @@ -34,7 +34,6 @@ from app.api.dependencies.auth import ( ) from app.core.config import settings from app.core.database import get_db -from app.crud import oauth_client as oauth_client_crud from app.models.user import User from app.schemas.oauth import ( OAuthClientCreate, @@ -712,7 +711,7 @@ async def register_client( client_type=client_type, ) - client, secret = await oauth_client_crud.create_client(db, obj_in=client_data) + client, secret = await provider_service.register_client(db, client_data) # Update MCP server URL if provided if mcp_server_url: @@ -750,7 +749,7 @@ async def list_clients( current_user: User = Depends(get_current_superuser), ) -> list[OAuthClientResponse]: """List all OAuth clients.""" - clients = await oauth_client_crud.get_all_clients(db) + clients = await provider_service.list_clients(db) return [OAuthClientResponse.model_validate(c) for c in clients] @@ -776,7 +775,7 @@ async def delete_client( detail="Client not found", ) - await oauth_client_crud.delete_client(db, client_id=client_id) + await provider_service.delete_client_by_id(db, client_id=client_id) # ============================================================================ @@ -797,30 +796,7 @@ async def list_my_consents( current_user: User = Depends(get_current_active_user), ) -> list[dict]: """List applications the user has authorized.""" - from sqlalchemy import select - - from app.models.oauth_client import OAuthClient - from app.models.oauth_provider_token import OAuthConsent - - result = await db.execute( - select(OAuthConsent, OAuthClient) - .join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id) - .where(OAuthConsent.user_id == current_user.id) - ) - rows = result.all() - - return [ - { - "client_id": consent.client_id, - "client_name": client.client_name, - "client_description": client.client_description, - "granted_scopes": consent.granted_scopes.split() - if consent.granted_scopes - else [], - "granted_at": consent.created_at.isoformat(), - } - for consent, client in rows - ] + return await provider_service.list_user_consents(db, user_id=current_user.id) @router.delete( diff --git a/backend/app/api/routes/organizations.py b/backend/app/api/routes/organizations.py index 6d15c0d..8784987 100755 --- a/backend/app/api/routes/organizations.py +++ b/backend/app/api/routes/organizations.py @@ -15,9 +15,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.api.dependencies.permissions import require_org_admin, require_org_membership from app.core.database import get_db -from app.core.exceptions import ErrorCode, NotFoundError -from app.crud.organization import organization as organization_crud from app.models.user import User +from app.services.organization_service import organization_service from app.schemas.common import ( PaginatedResponse, PaginationParams, @@ -54,7 +53,7 @@ async def get_my_organizations( """ try: # Get all org data in single query with JOIN and subquery - orgs_data = await organization_crud.get_user_organizations_with_details( + orgs_data = await organization_service.get_user_organizations_with_details( db, user_id=current_user.id, is_active=is_active ) @@ -100,13 +99,7 @@ async def get_organization( User must be a member of the organization. """ try: - org = await organization_crud.get(db, id=organization_id) - if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md) - raise NotFoundError( - detail=f"Organization {organization_id} not found", - error_code=ErrorCode.NOT_FOUND, - ) - + org = await organization_service.get_organization(db, str(organization_id)) org_dict = { "id": org.id, "name": org.name, @@ -116,14 +109,12 @@ async def get_organization( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": await organization_crud.get_member_count( + "member_count": await organization_service.get_member_count( db, organization_id=org.id ), } return OrganizationResponse(**org_dict) - except NotFoundError: # pragma: no cover - See above - raise except Exception as e: logger.error(f"Error getting organization: {e!s}", exc_info=True) raise @@ -149,7 +140,7 @@ async def get_organization_members( User must be a member of the organization to view members. """ try: - members, total = await organization_crud.get_organization_members( + members, total = await organization_service.get_organization_members( db, organization_id=organization_id, skip=pagination.offset, @@ -192,14 +183,10 @@ async def update_organization( Requires owner or admin role in the organization. """ try: - org = await organization_crud.get(db, id=organization_id) - if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md) - raise NotFoundError( - detail=f"Organization {organization_id} not found", - error_code=ErrorCode.NOT_FOUND, - ) - - updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in) + org = await organization_service.get_organization(db, str(organization_id)) + updated_org = await organization_service.update_organization( + db, org=org, obj_in=org_in + ) logger.info( f"User {current_user.email} updated organization {updated_org.name}" ) @@ -213,14 +200,12 @@ async def update_organization( "settings": updated_org.settings, "created_at": updated_org.created_at, "updated_at": updated_org.updated_at, - "member_count": await organization_crud.get_member_count( + "member_count": await organization_service.get_member_count( db, organization_id=updated_org.id ), } return OrganizationResponse(**org_dict) - except NotFoundError: # pragma: no cover - See above - raise except Exception as e: logger.error(f"Error updating organization: {e!s}", exc_info=True) raise diff --git a/backend/app/api/routes/sessions.py b/backend/app/api/routes/sessions.py index 4300353..ebb7274 100755 --- a/backend/app/api/routes/sessions.py +++ b/backend/app/api/routes/sessions.py @@ -17,8 +17,8 @@ from app.api.dependencies.auth import get_current_user from app.core.auth import decode_token from app.core.database import get_db from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError -from app.crud.session import session as session_crud from app.models.user import User +from app.services.session_service import session_service from app.schemas.common import MessageResponse from app.schemas.sessions import SessionListResponse, SessionResponse @@ -60,7 +60,7 @@ async def list_my_sessions( """ try: # Get all active sessions for user - sessions = await session_crud.get_user_sessions( + sessions = await session_service.get_user_sessions( db, user_id=str(current_user.id), active_only=True ) @@ -150,7 +150,7 @@ async def revoke_session( """ try: # Get the session - session = await session_crud.get(db, id=str(session_id)) + session = await session_service.get_session(db, str(session_id)) if not session: raise NotFoundError( @@ -170,7 +170,7 @@ async def revoke_session( ) # Deactivate the session - await session_crud.deactivate(db, session_id=str(session_id)) + await session_service.deactivate(db, session_id=str(session_id)) logger.info( f"User {current_user.id} revoked session {session_id} " @@ -224,7 +224,7 @@ async def cleanup_expired_sessions( """ try: # Use optimized bulk DELETE instead of N individual deletes - deleted_count = await session_crud.cleanup_expired_for_user( + deleted_count = await session_service.cleanup_expired_for_user( db, user_id=str(current_user.id) ) diff --git a/backend/app/api/routes/users.py b/backend/app/api/routes/users.py index 34790f8..c6ff0a9 100755 --- a/backend/app/api/routes/users.py +++ b/backend/app/api/routes/users.py @@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_superuser, get_current_user from app.core.database import get_db from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError -from app.crud.user import user as user_crud from app.models.user import User from app.schemas.common import ( MessageResponse, @@ -25,6 +24,7 @@ from app.schemas.common import ( ) from app.schemas.users import PasswordChange, UserResponse, UserUpdate from app.services.auth_service import AuthenticationError, AuthService +from app.services.user_service import user_service logger = logging.getLogger(__name__) @@ -71,7 +71,7 @@ async def list_users( filters["is_superuser"] = is_superuser # Get paginated users with total count - users, total = await user_crud.get_multi_with_total( + users, total = await user_service.list_users( db, skip=pagination.offset, limit=pagination.limit, @@ -107,7 +107,7 @@ async def list_users( """, operation_id="get_current_user_profile", ) -def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any: +async def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any: """Get current user's profile.""" return current_user @@ -138,8 +138,8 @@ async def update_current_user( Users cannot elevate their own permissions (protected by UserUpdate schema validator). """ try: - updated_user = await user_crud.update( - db, db_obj=current_user, obj_in=user_update + updated_user = await user_service.update_user( + db, user=current_user, obj_in=user_update ) logger.info(f"User {current_user.id} updated their profile") return updated_user @@ -190,13 +190,7 @@ async def get_user_by_id( ) # Get user - user = await user_crud.get(db, id=str(user_id)) - if not user: - raise NotFoundError( - message=f"User with id {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND, - ) - + user = await user_service.get_user(db, str(user_id)) return user @@ -241,15 +235,10 @@ async def update_user( ) # Get user - user = await user_crud.get(db, id=str(user_id)) - if not user: - raise NotFoundError( - message=f"User with id {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND, - ) + user = await user_service.get_user(db, str(user_id)) try: - updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update) + updated_user = await user_service.update_user(db, user=user, obj_in=user_update) logger.info(f"User {user_id} updated by {current_user.id}") return updated_user except ValueError as e: @@ -346,17 +335,12 @@ async def delete_user( error_code=ErrorCode.INSUFFICIENT_PERMISSIONS, ) - # Get user - user = await user_crud.get(db, id=str(user_id)) - if not user: - raise NotFoundError( - message=f"User with id {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND, - ) + # Get user (raises NotFoundError if not found) + await user_service.get_user(db, str(user_id)) try: # Use soft delete instead of hard delete - await user_crud.soft_delete(db, id=str(user_id)) + await user_service.soft_delete_user(db, str(user_id)) logger.info(f"User {user_id} soft-deleted by {current_user.id}") return MessageResponse( success=True, message=f"User {user_id} deleted successfully" diff --git a/backend/app/core/repository_exceptions.py b/backend/app/core/repository_exceptions.py new file mode 100644 index 0000000..a9a1ba9 --- /dev/null +++ b/backend/app/core/repository_exceptions.py @@ -0,0 +1,26 @@ +""" +Custom exceptions for the repository layer. + +These exceptions allow services and routes to handle database-level errors +with proper semantics, without leaking SQLAlchemy internals. +""" + + +class RepositoryError(Exception): + """Base for all repository-layer errors.""" + + +class DuplicateEntryError(RepositoryError): + """Raised on unique constraint violations. Maps to HTTP 409 Conflict.""" + + +class IntegrityConstraintError(RepositoryError): + """Raised on FK or check constraint violations.""" + + +class RecordNotFoundError(RepositoryError): + """Raised when an expected record doesn't exist.""" + + +class InvalidInputError(RepositoryError): + """Raised on bad pagination params, invalid UUIDs, or other invalid inputs.""" diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py deleted file mode 100644 index 47c43c3..0000000 --- a/backend/app/crud/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# app/crud/__init__.py -from .oauth import oauth_account, oauth_client, oauth_state -from .organization import organization -from .session import session as session_crud -from .user import user - -__all__ = [ - "oauth_account", - "oauth_client", - "oauth_state", - "organization", - "session_crud", - "user", -] diff --git a/backend/app/crud/oauth.py b/backend/app/crud/oauth.py deleted file mode 100755 index e97ddda..0000000 --- a/backend/app/crud/oauth.py +++ /dev/null @@ -1,718 +0,0 @@ -""" -Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns. - -Provides operations for: -- OAuthAccount: Managing linked OAuth provider accounts -- OAuthState: CSRF protection state during OAuth flows -- OAuthClient: Registered OAuth clients (provider mode skeleton) -""" - -import logging -import secrets -from datetime import UTC, datetime -from uuid import UUID - -from pydantic import BaseModel -from sqlalchemy import and_, delete, select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload - -from app.crud.base import CRUDBase -from app.models.oauth_account import OAuthAccount -from app.models.oauth_client import OAuthClient -from app.models.oauth_state import OAuthState -from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate - -logger = logging.getLogger(__name__) - - -# ============================================================================ -# OAuth Account CRUD -# ============================================================================ - - -class EmptySchema(BaseModel): - """Placeholder schema for CRUD operations that don't need update schemas.""" - - -class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]): - """CRUD operations for OAuth account links.""" - - async def get_by_provider_id( - self, - db: AsyncSession, - *, - provider: str, - provider_user_id: str, - ) -> OAuthAccount | None: - """ - Get OAuth account by provider and provider user ID. - - Args: - db: Database session - provider: OAuth provider name (google, github) - provider_user_id: User ID from the OAuth provider - - Returns: - OAuthAccount if found, None otherwise - """ - try: - result = await db.execute( - select(OAuthAccount) - .where( - and_( - OAuthAccount.provider == provider, - OAuthAccount.provider_user_id == provider_user_id, - ) - ) - .options(joinedload(OAuthAccount.user)) - ) - return result.scalar_one_or_none() - except Exception as e: # pragma: no cover # pragma: no cover - logger.error( - f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}" - ) - raise - - async def get_by_provider_email( - self, - db: AsyncSession, - *, - provider: str, - email: str, - ) -> OAuthAccount | None: - """ - Get OAuth account by provider and email. - - Used for auto-linking existing accounts by email. - - Args: - db: Database session - provider: OAuth provider name - email: Email address from the OAuth provider - - Returns: - OAuthAccount if found, None otherwise - """ - try: - result = await db.execute( - select(OAuthAccount) - .where( - and_( - OAuthAccount.provider == provider, - OAuthAccount.provider_email == email, - ) - ) - .options(joinedload(OAuthAccount.user)) - ) - return result.scalar_one_or_none() - except Exception as e: # pragma: no cover # pragma: no cover - logger.error( - f"Error getting OAuth account for {provider} email {email}: {e!s}" - ) - raise - - async def get_user_accounts( - self, - db: AsyncSession, - *, - user_id: str | UUID, - ) -> list[OAuthAccount]: - """ - Get all OAuth accounts linked to a user. - - Args: - db: Database session - user_id: User ID - - Returns: - List of OAuthAccount objects - """ - try: - user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id - - result = await db.execute( - select(OAuthAccount) - .where(OAuthAccount.user_id == user_uuid) - .order_by(OAuthAccount.created_at.desc()) - ) - return list(result.scalars().all()) - except Exception as e: # pragma: no cover - logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}") - raise - - async def get_user_account_by_provider( - self, - db: AsyncSession, - *, - user_id: str | UUID, - provider: str, - ) -> OAuthAccount | None: - """ - Get a specific OAuth account for a user and provider. - - Args: - db: Database session - user_id: User ID - provider: OAuth provider name - - Returns: - OAuthAccount if found, None otherwise - """ - try: - user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id - - result = await db.execute( - select(OAuthAccount).where( - and_( - OAuthAccount.user_id == user_uuid, - OAuthAccount.provider == provider, - ) - ) - ) - return result.scalar_one_or_none() - except Exception as e: # pragma: no cover - logger.error( - f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}" - ) - raise - - async def create_account( - self, db: AsyncSession, *, obj_in: OAuthAccountCreate - ) -> OAuthAccount: - """ - Create a new OAuth account link. - - Args: - db: Database session - obj_in: OAuth account creation data - - Returns: - Created OAuthAccount - - Raises: - ValueError: If account already exists or creation fails - """ - try: - db_obj = OAuthAccount( - user_id=obj_in.user_id, - provider=obj_in.provider, - provider_user_id=obj_in.provider_user_id, - provider_email=obj_in.provider_email, - access_token_encrypted=obj_in.access_token_encrypted, - refresh_token_encrypted=obj_in.refresh_token_encrypted, - token_expires_at=obj_in.token_expires_at, - ) - db.add(db_obj) - await db.commit() - await db.refresh(db_obj) - - logger.info( - f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}" - ) - return db_obj - except IntegrityError as e: # pragma: no cover - await db.rollback() - error_msg = str(e.orig) if hasattr(e, "orig") else str(e) - if "uq_oauth_provider_user" in error_msg.lower(): - logger.warning( - f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}" - ) - raise ValueError( - f"This {obj_in.provider} account is already linked to another user" - ) - logger.error(f"Integrity error creating OAuth account: {error_msg}") - raise ValueError(f"Failed to create OAuth account: {error_msg}") - except Exception as e: # pragma: no cover - await db.rollback() - logger.error(f"Error creating OAuth account: {e!s}", exc_info=True) - raise - - async def delete_account( - self, - db: AsyncSession, - *, - user_id: str | UUID, - provider: str, - ) -> bool: - """ - Delete an OAuth account link. - - Args: - db: Database session - user_id: User ID - provider: OAuth provider name - - Returns: - True if deleted, False if not found - """ - try: - user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id - - result = await db.execute( - delete(OAuthAccount).where( - and_( - OAuthAccount.user_id == user_uuid, - OAuthAccount.provider == provider, - ) - ) - ) - await db.commit() - - deleted = result.rowcount > 0 - if deleted: - logger.info( - f"OAuth account deleted: {provider} unlinked from user {user_id}" - ) - else: - logger.warning( - f"OAuth account not found for deletion: {provider} for user {user_id}" - ) - - return deleted - except Exception as e: # pragma: no cover - await db.rollback() - logger.error( - f"Error deleting OAuth account {provider} for user {user_id}: {e!s}" - ) - raise - - async def update_tokens( - self, - db: AsyncSession, - *, - account: OAuthAccount, - access_token_encrypted: str | None = None, - refresh_token_encrypted: str | None = None, - token_expires_at: datetime | None = None, - ) -> OAuthAccount: - """ - Update OAuth tokens for an account. - - Args: - db: Database session - account: OAuthAccount to update - access_token_encrypted: New encrypted access token - refresh_token_encrypted: New encrypted refresh token - token_expires_at: New token expiration time - - Returns: - Updated OAuthAccount - """ - try: - if access_token_encrypted is not None: - account.access_token_encrypted = access_token_encrypted - if refresh_token_encrypted is not None: - account.refresh_token_encrypted = refresh_token_encrypted - if token_expires_at is not None: - account.token_expires_at = token_expires_at - - db.add(account) - await db.commit() - await db.refresh(account) - - return account - except Exception as e: # pragma: no cover - await db.rollback() - logger.error(f"Error updating OAuth tokens: {e!s}") - raise - - -# ============================================================================ -# OAuth State CRUD -# ============================================================================ - - -class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]): - """CRUD operations for OAuth state (CSRF protection).""" - - async def create_state( - self, db: AsyncSession, *, obj_in: OAuthStateCreate - ) -> OAuthState: - """ - Create a new OAuth state for CSRF protection. - - Args: - db: Database session - obj_in: OAuth state creation data - - Returns: - Created OAuthState - """ - try: - db_obj = OAuthState( - state=obj_in.state, - code_verifier=obj_in.code_verifier, - nonce=obj_in.nonce, - provider=obj_in.provider, - redirect_uri=obj_in.redirect_uri, - user_id=obj_in.user_id, - expires_at=obj_in.expires_at, - ) - db.add(db_obj) - await db.commit() - await db.refresh(db_obj) - - logger.debug(f"OAuth state created for {obj_in.provider}") - return db_obj - except IntegrityError as e: # pragma: no cover - await db.rollback() - # State collision (extremely rare with cryptographic random) - error_msg = str(e.orig) if hasattr(e, "orig") else str(e) - logger.error(f"OAuth state collision: {error_msg}") - raise ValueError("Failed to create OAuth state, please retry") - except Exception as e: # pragma: no cover - await db.rollback() - logger.error(f"Error creating OAuth state: {e!s}", exc_info=True) - raise - - async def get_and_consume_state( - self, db: AsyncSession, *, state: str - ) -> OAuthState | None: - """ - Get and delete OAuth state (consume it). - - This ensures each state can only be used once (replay protection). - - Args: - db: Database session - state: State string to look up - - Returns: - OAuthState if found and valid, None otherwise - """ - try: - # Get the state - result = await db.execute( - select(OAuthState).where(OAuthState.state == state) - ) - db_obj = result.scalar_one_or_none() - - if db_obj is None: - logger.warning(f"OAuth state not found: {state[:8]}...") - return None - - # Check expiration - # Handle both timezone-aware and timezone-naive datetimes - now = datetime.now(UTC) - expires_at = db_obj.expires_at - if expires_at.tzinfo is None: - # SQLite returns naive datetimes, assume UTC - expires_at = expires_at.replace(tzinfo=UTC) - - if expires_at < now: - logger.warning(f"OAuth state expired: {state[:8]}...") - await db.delete(db_obj) - await db.commit() - return None - - # Delete it (consume) - await db.delete(db_obj) - await db.commit() - - logger.debug(f"OAuth state consumed: {state[:8]}...") - return db_obj - except Exception as e: # pragma: no cover - await db.rollback() - logger.error(f"Error consuming OAuth state: {e!s}") - raise - - async def cleanup_expired(self, db: AsyncSession) -> int: - """ - Clean up expired OAuth states. - - Should be called periodically to remove stale states. - - Args: - db: Database session - - Returns: - Number of states deleted - """ - try: - now = datetime.now(UTC) - - stmt = delete(OAuthState).where(OAuthState.expires_at < now) - result = await db.execute(stmt) - await db.commit() - - count = result.rowcount - if count > 0: - logger.info(f"Cleaned up {count} expired OAuth states") - - return count - except Exception as e: # pragma: no cover - await db.rollback() - logger.error(f"Error cleaning up expired OAuth states: {e!s}") - raise - - -# ============================================================================ -# OAuth Client CRUD (Provider Mode - Skeleton) -# ============================================================================ - - -class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]): - """ - CRUD operations for OAuth clients (provider mode). - - This is a skeleton implementation for MCP client registration. - Full implementation can be expanded when needed. - """ - - async def get_by_client_id( - self, db: AsyncSession, *, client_id: str - ) -> OAuthClient | None: - """ - Get OAuth client by client_id. - - Args: - db: Database session - client_id: OAuth client ID - - Returns: - OAuthClient if found, None otherwise - """ - try: - result = await db.execute( - select(OAuthClient).where( - and_( - OAuthClient.client_id == client_id, - OAuthClient.is_active == True, # noqa: E712 - ) - ) - ) - return result.scalar_one_or_none() - except Exception as e: # pragma: no cover - logger.error(f"Error getting OAuth client {client_id}: {e!s}") - raise - - async def create_client( - self, - db: AsyncSession, - *, - obj_in: OAuthClientCreate, - owner_user_id: UUID | None = None, - ) -> tuple[OAuthClient, str | None]: - """ - Create a new OAuth client. - - Args: - db: Database session - obj_in: OAuth client creation data - owner_user_id: Optional owner user ID - - Returns: - Tuple of (created OAuthClient, client_secret or None for public clients) - """ - try: - # Generate client_id - client_id = secrets.token_urlsafe(32) - - # Generate client_secret for confidential clients - client_secret = None - client_secret_hash = None - if obj_in.client_type == "confidential": - client_secret = secrets.token_urlsafe(48) - # SECURITY: Use bcrypt for secret storage (not SHA-256) - # bcrypt is computationally expensive, making brute-force attacks infeasible - from app.core.auth import get_password_hash - - client_secret_hash = get_password_hash(client_secret) - - db_obj = OAuthClient( - client_id=client_id, - client_secret_hash=client_secret_hash, - client_name=obj_in.client_name, - client_description=obj_in.client_description, - client_type=obj_in.client_type, - redirect_uris=obj_in.redirect_uris, - allowed_scopes=obj_in.allowed_scopes, - owner_user_id=owner_user_id, - is_active=True, - ) - db.add(db_obj) - await db.commit() - await db.refresh(db_obj) - - logger.info( - f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)" - ) - return db_obj, client_secret - except IntegrityError as e: # pragma: no cover - await db.rollback() - error_msg = str(e.orig) if hasattr(e, "orig") else str(e) - logger.error(f"Error creating OAuth client: {error_msg}") - raise ValueError(f"Failed to create OAuth client: {error_msg}") - except Exception as e: # pragma: no cover - await db.rollback() - logger.error(f"Error creating OAuth client: {e!s}", exc_info=True) - raise - - async def deactivate_client( - self, db: AsyncSession, *, client_id: str - ) -> OAuthClient | None: - """ - Deactivate an OAuth client. - - Args: - db: Database session - client_id: OAuth client ID - - Returns: - Deactivated OAuthClient if found, None otherwise - """ - try: - client = await self.get_by_client_id(db, client_id=client_id) - if client is None: - return None - - client.is_active = False - db.add(client) - await db.commit() - await db.refresh(client) - - logger.info(f"OAuth client deactivated: {client.client_name}") - return client - except Exception as e: # pragma: no cover - await db.rollback() - logger.error(f"Error deactivating OAuth client {client_id}: {e!s}") - raise - - async def validate_redirect_uri( - self, db: AsyncSession, *, client_id: str, redirect_uri: str - ) -> bool: - """ - Validate that a redirect URI is allowed for a client. - - Args: - db: Database session - client_id: OAuth client ID - redirect_uri: Redirect URI to validate - - Returns: - True if valid, False otherwise - """ - try: - client = await self.get_by_client_id(db, client_id=client_id) - if client is None: - return False - - return redirect_uri in (client.redirect_uris or []) - except Exception as e: # pragma: no cover - logger.error(f"Error validating redirect URI: {e!s}") - return False - - async def verify_client_secret( - self, db: AsyncSession, *, client_id: str, client_secret: str - ) -> bool: - """ - Verify client credentials. - - Args: - db: Database session - client_id: OAuth client ID - client_secret: Client secret to verify - - Returns: - True if valid, False otherwise - """ - try: - result = await db.execute( - select(OAuthClient).where( - and_( - OAuthClient.client_id == client_id, - OAuthClient.is_active == True, # noqa: E712 - ) - ) - ) - client = result.scalar_one_or_none() - - if client is None or client.client_secret_hash is None: - return False - - # SECURITY: Verify secret using bcrypt (not SHA-256) - # This supports both old SHA-256 hashes (for migration) and new bcrypt hashes - from app.core.auth import verify_password - - stored_hash: str = str(client.client_secret_hash) - - # Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256 - if stored_hash.startswith("$2"): - # New bcrypt format - return verify_password(client_secret, stored_hash) - else: - # Legacy SHA-256 format - still support for migration - import hashlib - - secret_hash = hashlib.sha256(client_secret.encode()).hexdigest() - return secrets.compare_digest(stored_hash, secret_hash) - except Exception as e: # pragma: no cover - logger.error(f"Error verifying client secret: {e!s}") - return False - - async def get_all_clients( - self, db: AsyncSession, *, include_inactive: bool = False - ) -> list[OAuthClient]: - """ - Get all OAuth clients. - - Args: - db: Database session - include_inactive: Whether to include inactive clients - - Returns: - List of OAuthClient objects - """ - try: - query = select(OAuthClient).order_by(OAuthClient.created_at.desc()) - if not include_inactive: - query = query.where(OAuthClient.is_active == True) # noqa: E712 - - result = await db.execute(query) - return list(result.scalars().all()) - except Exception as e: # pragma: no cover - logger.error(f"Error getting all OAuth clients: {e!s}") - raise - - async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool: - """ - Delete an OAuth client permanently. - - Note: This will cascade delete related records (tokens, consents, etc.) - due to foreign key constraints. - - Args: - db: Database session - client_id: OAuth client ID - - Returns: - True if deleted, False if not found - """ - try: - result = await db.execute( - delete(OAuthClient).where(OAuthClient.client_id == client_id) - ) - await db.commit() - - deleted = result.rowcount > 0 - if deleted: - logger.info(f"OAuth client deleted: {client_id}") - else: - logger.warning(f"OAuth client not found for deletion: {client_id}") - - return deleted - except Exception as e: # pragma: no cover - await db.rollback() - logger.error(f"Error deleting OAuth client {client_id}: {e!s}") - raise - - -# ============================================================================ -# Singleton instances -# ============================================================================ - -oauth_account = CRUDOAuthAccount(OAuthAccount) -oauth_state = CRUDOAuthState(OAuthState) -oauth_client = CRUDOAuthClient(OAuthClient) diff --git a/backend/app/init_db.py b/backend/app/init_db.py index d429a8d..07db637 100644 --- a/backend/app/init_db.py +++ b/backend/app/init_db.py @@ -16,7 +16,7 @@ from sqlalchemy import select, text from app.core.config import settings from app.core.database import SessionLocal, engine -from app.crud.user import user as user_crud +from app.repositories.user import user_repo as user_crud from app.models.organization import Organization from app.models.user import User from app.models.user_organization import UserOrganization diff --git a/backend/app/models/oauth_account.py b/backend/app/models/oauth_account.py index 2178cf8..7b2acfe 100755 --- a/backend/app/models/oauth_account.py +++ b/backend/app/models/oauth_account.py @@ -36,9 +36,9 @@ class OAuthAccount(Base, UUIDMixin, TimestampMixin): ) # Email from provider (for reference) # Optional: store provider tokens for API access - # These should be encrypted at rest in production - access_token_encrypted = Column(String(2048), nullable=True) - refresh_token_encrypted = Column(String(2048), nullable=True) + # TODO: Encrypt these at rest in production (requires key management infrastructure) + access_token = Column(String(2048), nullable=True) + refresh_token = Column(String(2048), nullable=True) token_expires_at = Column(DateTime(timezone=True), nullable=True) # Relationship diff --git a/backend/app/repositories/__init__.py b/backend/app/repositories/__init__.py new file mode 100644 index 0000000..0c62864 --- /dev/null +++ b/backend/app/repositories/__init__.py @@ -0,0 +1,39 @@ +# app/repositories/__init__.py +"""Repository layer — all database access goes through these classes.""" + +from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo +from app.repositories.oauth_authorization_code import ( + OAuthAuthorizationCodeRepository, + oauth_authorization_code_repo, +) +from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo +from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo +from app.repositories.oauth_provider_token import ( + OAuthProviderTokenRepository, + oauth_provider_token_repo, +) +from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo +from app.repositories.organization import OrganizationRepository, organization_repo +from app.repositories.session import SessionRepository, session_repo +from app.repositories.user import UserRepository, user_repo + +__all__ = [ + "UserRepository", + "user_repo", + "OrganizationRepository", + "organization_repo", + "SessionRepository", + "session_repo", + "OAuthAccountRepository", + "oauth_account_repo", + "OAuthAuthorizationCodeRepository", + "oauth_authorization_code_repo", + "OAuthClientRepository", + "oauth_client_repo", + "OAuthConsentRepository", + "oauth_consent_repo", + "OAuthProviderTokenRepository", + "oauth_provider_token_repo", + "OAuthStateRepository", + "oauth_state_repo", +] diff --git a/backend/app/crud/base.py b/backend/app/repositories/base.py old mode 100755 new mode 100644 similarity index 80% rename from backend/app/crud/base.py rename to backend/app/repositories/base.py index a977922..d8cac40 --- a/backend/app/crud/base.py +++ b/backend/app/repositories/base.py @@ -1,6 +1,6 @@ -# app/crud/base_async.py +# app/repositories/base.py """ -Async CRUD operations base class using SQLAlchemy 2.0 async patterns. +Base repository class for async CRUD operations using SQLAlchemy 2.0 async patterns. Provides reusable create, read, update, and delete operations for all models. """ @@ -18,6 +18,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Load from app.core.database import Base +from app.core.repository_exceptions import ( + DuplicateEntryError, + IntegrityConstraintError, + InvalidInputError, +) logger = logging.getLogger(__name__) @@ -26,16 +31,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) -class CRUDBase[ +class BaseRepository[ ModelType: Base, CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel, ]: - """Async CRUD operations for a model.""" + """Async repository operations for a model.""" def __init__(self, model: type[ModelType]): """ - CRUD object with default async methods to Create, Read, Update, Delete. + Repository object with default async methods to Create, Read, Update, Delete. Parameters: model: A SQLAlchemy model class @@ -56,13 +61,7 @@ class CRUDBase[ Returns: Model instance or None if not found - - Example: - # Eager load user relationship - from sqlalchemy.orm import joinedload - session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)]) """ - # Validate UUID format and convert to UUID object if string try: if isinstance(id, uuid.UUID): uuid_obj = id @@ -75,7 +74,6 @@ class CRUDBase[ try: query = select(self.model).where(self.model.id == uuid_obj) - # Apply eager loading options if provided if options: for option in options: query = query.options(option) @@ -96,28 +94,17 @@ class CRUDBase[ ) -> list[ModelType]: """ Get multiple records with pagination validation and optional eager loading. - - Args: - db: Database session - skip: Number of records to skip - limit: Maximum number of records to return - options: Optional list of SQLAlchemy load options for eager loading - - Returns: - List of model instances """ - # Validate pagination parameters if skip < 0: - raise ValueError("skip must be non-negative") + raise InvalidInputError("skip must be non-negative") if limit < 0: - raise ValueError("limit must be non-negative") + raise InvalidInputError("limit must be non-negative") if limit > 1000: - raise ValueError("Maximum limit is 1000") + raise InvalidInputError("Maximum limit is 1000") try: - query = select(self.model).offset(skip).limit(limit) + query = select(self.model).order_by(self.model.id).offset(skip).limit(limit) - # Apply eager loading options if provided if options: for option in options: query = query.options(option) @@ -136,9 +123,8 @@ class CRUDBase[ """Create a new record with error handling. NOTE: This method is defensive code that's never called in practice. - All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method - with their own implementations, so the base implementation and its exception handlers - are never executed. Marked as pragma: no cover to avoid false coverage gaps. + All repository subclasses override this method with their own implementations. + Marked as pragma: no cover to avoid false coverage gaps. """ try: # pragma: no cover obj_in_data = jsonable_encoder(obj_in) @@ -154,15 +140,15 @@ class CRUDBase[ logger.warning( f"Duplicate entry attempted for {self.model.__name__}: {error_msg}" ) - raise ValueError( + raise DuplicateEntryError( f"A {self.model.__name__} with this data already exists" ) logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}") - raise ValueError(f"Database integrity error: {error_msg}") + raise IntegrityConstraintError(f"Database integrity error: {error_msg}") except (OperationalError, DataError) as e: # pragma: no cover await db.rollback() logger.error(f"Database error creating {self.model.__name__}: {e!s}") - raise ValueError(f"Database operation failed: {e!s}") + raise IntegrityConstraintError(f"Database operation failed: {e!s}") except Exception as e: # pragma: no cover await db.rollback() logger.error( @@ -200,15 +186,15 @@ class CRUDBase[ logger.warning( f"Duplicate entry attempted for {self.model.__name__}: {error_msg}" ) - raise ValueError( + raise DuplicateEntryError( f"A {self.model.__name__} with this data already exists" ) logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}") - raise ValueError(f"Database integrity error: {error_msg}") + raise IntegrityConstraintError(f"Database integrity error: {error_msg}") except (OperationalError, DataError) as e: await db.rollback() logger.error(f"Database error updating {self.model.__name__}: {e!s}") - raise ValueError(f"Database operation failed: {e!s}") + raise IntegrityConstraintError(f"Database operation failed: {e!s}") except Exception as e: await db.rollback() logger.error( @@ -218,7 +204,6 @@ class CRUDBase[ async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None: """Delete a record with error handling and null check.""" - # Validate UUID format and convert to UUID object if string try: if isinstance(id, uuid.UUID): uuid_obj = id @@ -247,7 +232,7 @@ class CRUDBase[ await db.rollback() error_msg = str(e.orig) if hasattr(e, "orig") else str(e) logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}") - raise ValueError( + raise IntegrityConstraintError( f"Cannot delete {self.model.__name__}: referenced by other records" ) except Exception as e: @@ -272,57 +257,40 @@ class CRUDBase[ Get multiple records with total count, filtering, and sorting. NOTE: This method is defensive code that's never called in practice. - All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method - with their own implementations that include additional parameters like search. + All repository subclasses override this method with their own implementations. Marked as pragma: no cover to avoid false coverage gaps. - - Args: - db: Database session - skip: Number of records to skip - limit: Maximum number of records to return - sort_by: Field name to sort by (must be a valid model attribute) - sort_order: Sort order ("asc" or "desc") - filters: Dictionary of filters (field_name: value) - - Returns: - Tuple of (items, total_count) """ - # Validate pagination parameters if skip < 0: - raise ValueError("skip must be non-negative") + raise InvalidInputError("skip must be non-negative") if limit < 0: - raise ValueError("limit must be non-negative") + raise InvalidInputError("limit must be non-negative") if limit > 1000: - raise ValueError("Maximum limit is 1000") + raise InvalidInputError("Maximum limit is 1000") try: - # Build base query query = select(self.model) - # Exclude soft-deleted records by default if hasattr(self.model, "deleted_at"): query = query.where(self.model.deleted_at.is_(None)) - # Apply filters if filters: for field, value in filters.items(): if hasattr(self.model, field) and value is not None: query = query.where(getattr(self.model, field) == value) - # Get total count (before pagination) count_query = select(func.count()).select_from(query.alias()) count_result = await db.execute(count_query) total = count_result.scalar_one() - # Apply sorting if sort_by and hasattr(self.model, sort_by): sort_column = getattr(self.model, sort_by) if sort_order.lower() == "desc": query = query.order_by(sort_column.desc()) else: query = query.order_by(sort_column.asc()) + else: + query = query.order_by(self.model.id) - # Apply pagination query = query.offset(skip).limit(limit) items_result = await db.execute(query) items = list(items_result.scalars().all()) @@ -356,7 +324,6 @@ class CRUDBase[ """ from datetime import datetime - # Validate UUID format and convert to UUID object if string try: if isinstance(id, uuid.UUID): uuid_obj = id @@ -378,14 +345,12 @@ class CRUDBase[ ) return None - # Check if model supports soft deletes if not hasattr(self.model, "deleted_at"): logger.error(f"{self.model.__name__} does not support soft deletes") - raise ValueError( + raise InvalidInputError( f"{self.model.__name__} does not have a deleted_at column" ) - # Set deleted_at timestamp obj.deleted_at = datetime.now(UTC) db.add(obj) await db.commit() @@ -405,7 +370,6 @@ class CRUDBase[ Only works if the model has a 'deleted_at' column. """ - # Validate UUID format try: if isinstance(id, uuid.UUID): uuid_obj = id @@ -416,7 +380,6 @@ class CRUDBase[ return None try: - # Find the soft-deleted record if hasattr(self.model, "deleted_at"): result = await db.execute( select(self.model).where( @@ -426,7 +389,7 @@ class CRUDBase[ obj = result.scalar_one_or_none() else: logger.error(f"{self.model.__name__} does not support soft deletes") - raise ValueError( + raise InvalidInputError( f"{self.model.__name__} does not have a deleted_at column" ) @@ -436,7 +399,6 @@ class CRUDBase[ ) return None - # Clear deleted_at timestamp obj.deleted_at = None db.add(obj) await db.commit() @@ -449,3 +411,4 @@ class CRUDBase[ exc_info=True, ) raise + diff --git a/backend/app/repositories/oauth_account.py b/backend/app/repositories/oauth_account.py new file mode 100644 index 0000000..d24ac33 --- /dev/null +++ b/backend/app/repositories/oauth_account.py @@ -0,0 +1,235 @@ +# app/repositories/oauth_account.py +"""Repository for OAuthAccount model async CRUD operations.""" + +import logging +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel +from sqlalchemy import and_, delete, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from app.core.repository_exceptions import DuplicateEntryError +from app.models.oauth_account import OAuthAccount +from app.repositories.base import BaseRepository +from app.schemas.oauth import OAuthAccountCreate + +logger = logging.getLogger(__name__) + + +class EmptySchema(BaseModel): + """Placeholder schema for repository operations that don't need update schemas.""" + + +class OAuthAccountRepository(BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]): + """Repository for OAuth account links.""" + + async def get_by_provider_id( + self, + db: AsyncSession, + *, + provider: str, + provider_user_id: str, + ) -> OAuthAccount | None: + """Get OAuth account by provider and provider user ID.""" + try: + result = await db.execute( + select(OAuthAccount) + .where( + and_( + OAuthAccount.provider == provider, + OAuthAccount.provider_user_id == provider_user_id, + ) + ) + .options(joinedload(OAuthAccount.user)) + ) + return result.scalar_one_or_none() + except Exception as e: # pragma: no cover + logger.error( + f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}" + ) + raise + + async def get_by_provider_email( + self, + db: AsyncSession, + *, + provider: str, + email: str, + ) -> OAuthAccount | None: + """Get OAuth account by provider and email.""" + try: + result = await db.execute( + select(OAuthAccount) + .where( + and_( + OAuthAccount.provider == provider, + OAuthAccount.provider_email == email, + ) + ) + .options(joinedload(OAuthAccount.user)) + ) + return result.scalar_one_or_none() + except Exception as e: # pragma: no cover + logger.error( + f"Error getting OAuth account for {provider} email {email}: {e!s}" + ) + raise + + async def get_user_accounts( + self, + db: AsyncSession, + *, + user_id: str | UUID, + ) -> list[OAuthAccount]: + """Get all OAuth accounts linked to a user.""" + try: + user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id + + result = await db.execute( + select(OAuthAccount) + .where(OAuthAccount.user_id == user_uuid) + .order_by(OAuthAccount.created_at.desc()) + ) + return list(result.scalars().all()) + except Exception as e: # pragma: no cover + logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}") + raise + + async def get_user_account_by_provider( + self, + db: AsyncSession, + *, + user_id: str | UUID, + provider: str, + ) -> OAuthAccount | None: + """Get a specific OAuth account for a user and provider.""" + try: + user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id + + result = await db.execute( + select(OAuthAccount).where( + and_( + OAuthAccount.user_id == user_uuid, + OAuthAccount.provider == provider, + ) + ) + ) + return result.scalar_one_or_none() + except Exception as e: # pragma: no cover + logger.error( + f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}" + ) + raise + + async def create_account( + self, db: AsyncSession, *, obj_in: OAuthAccountCreate + ) -> OAuthAccount: + """Create a new OAuth account link.""" + try: + db_obj = OAuthAccount( + user_id=obj_in.user_id, + provider=obj_in.provider, + provider_user_id=obj_in.provider_user_id, + provider_email=obj_in.provider_email, + access_token=obj_in.access_token, + refresh_token=obj_in.refresh_token, + token_expires_at=obj_in.token_expires_at, + ) + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + + logger.info( + f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}" + ) + return db_obj + except IntegrityError as e: # pragma: no cover + await db.rollback() + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) + if "uq_oauth_provider_user" in error_msg.lower(): + logger.warning( + f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}" + ) + raise DuplicateEntryError( + f"This {obj_in.provider} account is already linked to another user" + ) + logger.error(f"Integrity error creating OAuth account: {error_msg}") + raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}") + except Exception as e: # pragma: no cover + await db.rollback() + logger.error(f"Error creating OAuth account: {e!s}", exc_info=True) + raise + + async def delete_account( + self, + db: AsyncSession, + *, + user_id: str | UUID, + provider: str, + ) -> bool: + """Delete an OAuth account link.""" + try: + user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id + + result = await db.execute( + delete(OAuthAccount).where( + and_( + OAuthAccount.user_id == user_uuid, + OAuthAccount.provider == provider, + ) + ) + ) + await db.commit() + + deleted = result.rowcount > 0 + if deleted: + logger.info( + f"OAuth account deleted: {provider} unlinked from user {user_id}" + ) + else: + logger.warning( + f"OAuth account not found for deletion: {provider} for user {user_id}" + ) + + return deleted + except Exception as e: # pragma: no cover + await db.rollback() + logger.error( + f"Error deleting OAuth account {provider} for user {user_id}: {e!s}" + ) + raise + + async def update_tokens( + self, + db: AsyncSession, + *, + account: OAuthAccount, + access_token: str | None = None, + refresh_token: str | None = None, + token_expires_at: datetime | None = None, + ) -> OAuthAccount: + """Update OAuth tokens for an account.""" + try: + if access_token is not None: + account.access_token = access_token + if refresh_token is not None: + account.refresh_token = refresh_token + if token_expires_at is not None: + account.token_expires_at = token_expires_at + + db.add(account) + await db.commit() + await db.refresh(account) + + return account + except Exception as e: # pragma: no cover + await db.rollback() + logger.error(f"Error updating OAuth tokens: {e!s}") + raise + + +# Singleton instance +oauth_account_repo = OAuthAccountRepository(OAuthAccount) diff --git a/backend/app/repositories/oauth_authorization_code.py b/backend/app/repositories/oauth_authorization_code.py new file mode 100644 index 0000000..c3c0453 --- /dev/null +++ b/backend/app/repositories/oauth_authorization_code.py @@ -0,0 +1,108 @@ +# app/repositories/oauth_authorization_code.py +"""Repository for OAuthAuthorizationCode model.""" + +import logging +from datetime import UTC, datetime +from uuid import UUID + +from sqlalchemy import and_, delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.oauth_authorization_code import OAuthAuthorizationCode + +logger = logging.getLogger(__name__) + + +class OAuthAuthorizationCodeRepository: + """Repository for OAuth 2.0 authorization codes.""" + + async def create_code( + self, + db: AsyncSession, + *, + code: str, + client_id: str, + user_id: UUID, + redirect_uri: str, + scope: str, + expires_at: datetime, + code_challenge: str | None = None, + code_challenge_method: str | None = None, + state: str | None = None, + nonce: str | None = None, + ) -> OAuthAuthorizationCode: + """Create and persist a new authorization code.""" + auth_code = OAuthAuthorizationCode( + code=code, + client_id=client_id, + user_id=user_id, + redirect_uri=redirect_uri, + scope=scope, + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + state=state, + nonce=nonce, + expires_at=expires_at, + used=False, + ) + db.add(auth_code) + await db.commit() + return auth_code + + async def consume_code_atomically( + self, db: AsyncSession, *, code: str + ) -> UUID | None: + """ + Atomically mark a code as used and return its UUID. + + Returns the UUID if the code was found and not yet used, None otherwise. + This prevents race conditions per RFC 6749 Section 4.1.2. + """ + stmt = ( + update(OAuthAuthorizationCode) + .where( + and_( + OAuthAuthorizationCode.code == code, + OAuthAuthorizationCode.used == False, # noqa: E712 + ) + ) + .values(used=True) + .returning(OAuthAuthorizationCode.id) + ) + result = await db.execute(stmt) + row_id = result.scalar_one_or_none() + if row_id is not None: + await db.commit() + return row_id + + async def get_by_id( + self, db: AsyncSession, *, code_id: UUID + ) -> OAuthAuthorizationCode | None: + """Get authorization code by its UUID primary key.""" + result = await db.execute( + select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id) + ) + return result.scalar_one_or_none() + + async def get_by_code( + self, db: AsyncSession, *, code: str + ) -> OAuthAuthorizationCode | None: + """Get authorization code by the code string value.""" + result = await db.execute( + select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code) + ) + return result.scalar_one_or_none() + + async def cleanup_expired(self, db: AsyncSession) -> int: + """Delete all expired authorization codes. Returns count deleted.""" + result = await db.execute( + delete(OAuthAuthorizationCode).where( + OAuthAuthorizationCode.expires_at < datetime.now(UTC) + ) + ) + await db.commit() + return result.rowcount # type: ignore[attr-defined] + + +# Singleton instance +oauth_authorization_code_repo = OAuthAuthorizationCodeRepository() diff --git a/backend/app/repositories/oauth_client.py b/backend/app/repositories/oauth_client.py new file mode 100644 index 0000000..445b93c --- /dev/null +++ b/backend/app/repositories/oauth_client.py @@ -0,0 +1,199 @@ +# app/repositories/oauth_client.py +"""Repository for OAuthClient model async CRUD operations.""" + +import logging +import secrets +from uuid import UUID + +from pydantic import BaseModel +from sqlalchemy import and_, delete, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.repository_exceptions import DuplicateEntryError +from app.models.oauth_client import OAuthClient +from app.repositories.base import BaseRepository +from app.schemas.oauth import OAuthClientCreate + +logger = logging.getLogger(__name__) + + +class EmptySchema(BaseModel): + """Placeholder schema for repository operations that don't need update schemas.""" + + +class OAuthClientRepository(BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]): + """Repository for OAuth clients (provider mode).""" + + async def get_by_client_id( + self, db: AsyncSession, *, client_id: str + ) -> OAuthClient | None: + """Get OAuth client by client_id.""" + try: + result = await db.execute( + select(OAuthClient).where( + and_( + OAuthClient.client_id == client_id, + OAuthClient.is_active == True, # noqa: E712 + ) + ) + ) + return result.scalar_one_or_none() + except Exception as e: # pragma: no cover + logger.error(f"Error getting OAuth client {client_id}: {e!s}") + raise + + async def create_client( + self, + db: AsyncSession, + *, + obj_in: OAuthClientCreate, + owner_user_id: UUID | None = None, + ) -> tuple[OAuthClient, str | None]: + """Create a new OAuth client.""" + try: + client_id = secrets.token_urlsafe(32) + + client_secret = None + client_secret_hash = None + if obj_in.client_type == "confidential": + client_secret = secrets.token_urlsafe(48) + from app.core.auth import get_password_hash + + client_secret_hash = get_password_hash(client_secret) + + db_obj = OAuthClient( + client_id=client_id, + client_secret_hash=client_secret_hash, + client_name=obj_in.client_name, + client_description=obj_in.client_description, + client_type=obj_in.client_type, + redirect_uris=obj_in.redirect_uris, + allowed_scopes=obj_in.allowed_scopes, + owner_user_id=owner_user_id, + is_active=True, + ) + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + + logger.info( + f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)" + ) + return db_obj, client_secret + except IntegrityError as e: # pragma: no cover + await db.rollback() + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) + logger.error(f"Error creating OAuth client: {error_msg}") + raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}") + except Exception as e: # pragma: no cover + await db.rollback() + logger.error(f"Error creating OAuth client: {e!s}", exc_info=True) + raise + + async def deactivate_client( + self, db: AsyncSession, *, client_id: str + ) -> OAuthClient | None: + """Deactivate an OAuth client.""" + try: + client = await self.get_by_client_id(db, client_id=client_id) + if client is None: + return None + + client.is_active = False + db.add(client) + await db.commit() + await db.refresh(client) + + logger.info(f"OAuth client deactivated: {client.client_name}") + return client + except Exception as e: # pragma: no cover + await db.rollback() + logger.error(f"Error deactivating OAuth client {client_id}: {e!s}") + raise + + async def validate_redirect_uri( + self, db: AsyncSession, *, client_id: str, redirect_uri: str + ) -> bool: + """Validate that a redirect URI is allowed for a client.""" + try: + client = await self.get_by_client_id(db, client_id=client_id) + if client is None: + return False + + return redirect_uri in (client.redirect_uris or []) + except Exception as e: # pragma: no cover + logger.error(f"Error validating redirect URI: {e!s}") + return False + + async def verify_client_secret( + self, db: AsyncSession, *, client_id: str, client_secret: str + ) -> bool: + """Verify client credentials.""" + try: + result = await db.execute( + select(OAuthClient).where( + and_( + OAuthClient.client_id == client_id, + OAuthClient.is_active == True, # noqa: E712 + ) + ) + ) + client = result.scalar_one_or_none() + + if client is None or client.client_secret_hash is None: + return False + + from app.core.auth import verify_password + + stored_hash: str = str(client.client_secret_hash) + + if stored_hash.startswith("$2"): + return verify_password(client_secret, stored_hash) + else: + import hashlib + + secret_hash = hashlib.sha256(client_secret.encode()).hexdigest() + return secrets.compare_digest(stored_hash, secret_hash) + except Exception as e: # pragma: no cover + logger.error(f"Error verifying client secret: {e!s}") + return False + + async def get_all_clients( + self, db: AsyncSession, *, include_inactive: bool = False + ) -> list[OAuthClient]: + """Get all OAuth clients.""" + try: + query = select(OAuthClient).order_by(OAuthClient.created_at.desc()) + if not include_inactive: + query = query.where(OAuthClient.is_active == True) # noqa: E712 + + result = await db.execute(query) + return list(result.scalars().all()) + except Exception as e: # pragma: no cover + logger.error(f"Error getting all OAuth clients: {e!s}") + raise + + async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool: + """Delete an OAuth client permanently.""" + try: + result = await db.execute( + delete(OAuthClient).where(OAuthClient.client_id == client_id) + ) + await db.commit() + + deleted = result.rowcount > 0 + if deleted: + logger.info(f"OAuth client deleted: {client_id}") + else: + logger.warning(f"OAuth client not found for deletion: {client_id}") + + return deleted + except Exception as e: # pragma: no cover + await db.rollback() + logger.error(f"Error deleting OAuth client {client_id}: {e!s}") + raise + + +# Singleton instance +oauth_client_repo = OAuthClientRepository(OAuthClient) diff --git a/backend/app/repositories/oauth_consent.py b/backend/app/repositories/oauth_consent.py new file mode 100644 index 0000000..4aac4d8 --- /dev/null +++ b/backend/app/repositories/oauth_consent.py @@ -0,0 +1,112 @@ +# app/repositories/oauth_consent.py +"""Repository for OAuthConsent model.""" + +import logging +from uuid import UUID + +from typing import Any + +from sqlalchemy import and_, delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.oauth_client import OAuthClient +from app.models.oauth_provider_token import OAuthConsent + +logger = logging.getLogger(__name__) + + +class OAuthConsentRepository: + """Repository for OAuth consent records (user grants to clients).""" + + async def get_consent( + self, db: AsyncSession, *, user_id: UUID, client_id: str + ) -> OAuthConsent | None: + """Get the consent record for a user-client pair, or None if not found.""" + result = await db.execute( + select(OAuthConsent).where( + and_( + OAuthConsent.user_id == user_id, + OAuthConsent.client_id == client_id, + ) + ) + ) + return result.scalar_one_or_none() + + async def grant_consent( + self, + db: AsyncSession, + *, + user_id: UUID, + client_id: str, + scopes: list[str], + ) -> OAuthConsent: + """ + Create or update consent for a user-client pair. + + If consent already exists, the new scopes are merged with existing ones. + Returns the created or updated consent record. + """ + consent = await self.get_consent(db, user_id=user_id, client_id=client_id) + + if consent: + existing = set(consent.granted_scopes.split()) if consent.granted_scopes else set() + merged = existing | set(scopes) + consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment] + else: + consent = OAuthConsent( + user_id=user_id, + client_id=client_id, + granted_scopes=" ".join(sorted(set(scopes))), + ) + db.add(consent) + + await db.commit() + await db.refresh(consent) + return consent + + async def get_user_consents_with_clients( + self, db: AsyncSession, *, user_id: UUID + ) -> list[dict[str, Any]]: + """Get all consent records for a user joined with client details.""" + result = await db.execute( + select(OAuthConsent, OAuthClient) + .join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id) + .where(OAuthConsent.user_id == user_id) + ) + rows = result.all() + return [ + { + "client_id": consent.client_id, + "client_name": client.client_name, + "client_description": client.client_description, + "granted_scopes": consent.granted_scopes.split() + if consent.granted_scopes + else [], + "granted_at": consent.created_at.isoformat(), + } + for consent, client in rows + ] + + async def revoke_consent( + self, db: AsyncSession, *, user_id: UUID, client_id: str + ) -> bool: + """ + Delete the consent record for a user-client pair. + + Returns True if a record was found and deleted. + Note: Callers are responsible for also revoking associated tokens. + """ + result = await db.execute( + delete(OAuthConsent).where( + and_( + OAuthConsent.user_id == user_id, + OAuthConsent.client_id == client_id, + ) + ) + ) + await db.commit() + return result.rowcount > 0 # type: ignore[attr-defined] + + +# Singleton instance +oauth_consent_repo = OAuthConsentRepository() diff --git a/backend/app/repositories/oauth_provider_token.py b/backend/app/repositories/oauth_provider_token.py new file mode 100644 index 0000000..9696116 --- /dev/null +++ b/backend/app/repositories/oauth_provider_token.py @@ -0,0 +1,146 @@ +# app/repositories/oauth_provider_token.py +"""Repository for OAuthProviderRefreshToken model.""" + +import logging +from datetime import UTC, datetime, timedelta +from uuid import UUID + +from sqlalchemy import and_, delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.oauth_provider_token import OAuthProviderRefreshToken + +logger = logging.getLogger(__name__) + + +class OAuthProviderTokenRepository: + """Repository for OAuth provider refresh tokens.""" + + async def create_token( + self, + db: AsyncSession, + *, + token_hash: str, + jti: str, + client_id: str, + user_id: UUID, + scope: str, + expires_at: datetime, + device_info: str | None = None, + ip_address: str | None = None, + ) -> OAuthProviderRefreshToken: + """Create and persist a new refresh token record.""" + token = OAuthProviderRefreshToken( + token_hash=token_hash, + jti=jti, + client_id=client_id, + user_id=user_id, + scope=scope, + expires_at=expires_at, + device_info=device_info, + ip_address=ip_address, + ) + db.add(token) + await db.commit() + return token + + async def get_by_token_hash( + self, db: AsyncSession, *, token_hash: str + ) -> OAuthProviderRefreshToken | None: + """Get refresh token record by SHA-256 token hash.""" + result = await db.execute( + select(OAuthProviderRefreshToken).where( + OAuthProviderRefreshToken.token_hash == token_hash + ) + ) + return result.scalar_one_or_none() + + async def get_by_jti( + self, db: AsyncSession, *, jti: str + ) -> OAuthProviderRefreshToken | None: + """Get refresh token record by JWT ID (JTI).""" + result = await db.execute( + select(OAuthProviderRefreshToken).where( + OAuthProviderRefreshToken.jti == jti + ) + ) + return result.scalar_one_or_none() + + async def revoke( + self, db: AsyncSession, *, token: OAuthProviderRefreshToken + ) -> None: + """Mark a specific token record as revoked.""" + token.revoked = True # type: ignore[assignment] + token.last_used_at = datetime.now(UTC) # type: ignore[assignment] + await db.commit() + + async def revoke_all_for_user_client( + self, db: AsyncSession, *, user_id: UUID, client_id: str + ) -> int: + """ + Revoke all active tokens for a specific user-client pair. + + Used when security incidents are detected (e.g., authorization code reuse). + Returns the number of tokens revoked. + """ + result = await db.execute( + update(OAuthProviderRefreshToken) + .where( + and_( + OAuthProviderRefreshToken.user_id == user_id, + OAuthProviderRefreshToken.client_id == client_id, + OAuthProviderRefreshToken.revoked == False, # noqa: E712 + ) + ) + .values(revoked=True) + ) + count = result.rowcount # type: ignore[attr-defined] + if count > 0: + await db.commit() + return count + + async def revoke_all_for_user( + self, db: AsyncSession, *, user_id: UUID + ) -> int: + """ + Revoke all active tokens for a user across all clients. + + Used when user changes password or logs out everywhere. + Returns the number of tokens revoked. + """ + result = await db.execute( + update(OAuthProviderRefreshToken) + .where( + and_( + OAuthProviderRefreshToken.user_id == user_id, + OAuthProviderRefreshToken.revoked == False, # noqa: E712 + ) + ) + .values(revoked=True) + ) + count = result.rowcount # type: ignore[attr-defined] + if count > 0: + await db.commit() + return count + + async def cleanup_expired( + self, db: AsyncSession, *, cutoff_days: int = 7 + ) -> int: + """ + Delete expired refresh tokens older than cutoff_days. + + Should be called periodically (e.g., daily). + Returns the number of tokens deleted. + """ + cutoff = datetime.now(UTC) - timedelta(days=cutoff_days) + result = await db.execute( + delete(OAuthProviderRefreshToken).where( + OAuthProviderRefreshToken.expires_at < cutoff + ) + ) + await db.commit() + return result.rowcount # type: ignore[attr-defined] + + +# Singleton instance +oauth_provider_token_repo = OAuthProviderTokenRepository() diff --git a/backend/app/repositories/oauth_state.py b/backend/app/repositories/oauth_state.py new file mode 100644 index 0000000..77a0ac0 --- /dev/null +++ b/backend/app/repositories/oauth_state.py @@ -0,0 +1,113 @@ +# app/repositories/oauth_state.py +"""Repository for OAuthState model async CRUD operations.""" + +import logging +from datetime import UTC, datetime + +from pydantic import BaseModel +from sqlalchemy import delete, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.repository_exceptions import DuplicateEntryError +from app.models.oauth_state import OAuthState +from app.repositories.base import BaseRepository +from app.schemas.oauth import OAuthStateCreate + +logger = logging.getLogger(__name__) + + +class EmptySchema(BaseModel): + """Placeholder schema for repository operations that don't need update schemas.""" + + +class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]): + """Repository for OAuth state (CSRF protection).""" + + async def create_state( + self, db: AsyncSession, *, obj_in: OAuthStateCreate + ) -> OAuthState: + """Create a new OAuth state for CSRF protection.""" + try: + db_obj = OAuthState( + state=obj_in.state, + code_verifier=obj_in.code_verifier, + nonce=obj_in.nonce, + provider=obj_in.provider, + redirect_uri=obj_in.redirect_uri, + user_id=obj_in.user_id, + expires_at=obj_in.expires_at, + ) + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + + logger.debug(f"OAuth state created for {obj_in.provider}") + return db_obj + except IntegrityError as e: # pragma: no cover + await db.rollback() + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) + logger.error(f"OAuth state collision: {error_msg}") + raise DuplicateEntryError("Failed to create OAuth state, please retry") + except Exception as e: # pragma: no cover + await db.rollback() + logger.error(f"Error creating OAuth state: {e!s}", exc_info=True) + raise + + async def get_and_consume_state( + self, db: AsyncSession, *, state: str + ) -> OAuthState | None: + """Get and delete OAuth state (consume it).""" + try: + result = await db.execute( + select(OAuthState).where(OAuthState.state == state) + ) + db_obj = result.scalar_one_or_none() + + if db_obj is None: + logger.warning(f"OAuth state not found: {state[:8]}...") + return None + + now = datetime.now(UTC) + expires_at = db_obj.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + + if expires_at < now: + logger.warning(f"OAuth state expired: {state[:8]}...") + await db.delete(db_obj) + await db.commit() + return None + + await db.delete(db_obj) + await db.commit() + + logger.debug(f"OAuth state consumed: {state[:8]}...") + return db_obj + except Exception as e: # pragma: no cover + await db.rollback() + logger.error(f"Error consuming OAuth state: {e!s}") + raise + + async def cleanup_expired(self, db: AsyncSession) -> int: + """Clean up expired OAuth states.""" + try: + now = datetime.now(UTC) + + stmt = delete(OAuthState).where(OAuthState.expires_at < now) + result = await db.execute(stmt) + await db.commit() + + count = result.rowcount + if count > 0: + logger.info(f"Cleaned up {count} expired OAuth states") + + return count + except Exception as e: # pragma: no cover + await db.rollback() + logger.error(f"Error cleaning up expired OAuth states: {e!s}") + raise + + +# Singleton instance +oauth_state_repo = OAuthStateRepository(OAuthState) diff --git a/backend/app/crud/organization.py b/backend/app/repositories/organization.py old mode 100755 new mode 100644 similarity index 87% rename from backend/app/crud/organization.py rename to backend/app/repositories/organization.py index 85ef256..11e93ad --- a/backend/app/crud/organization.py +++ b/backend/app/repositories/organization.py @@ -1,5 +1,5 @@ -# app/crud/organization_async.py -"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns.""" +# app/repositories/organization.py +"""Repository for Organization model async CRUD operations using SQLAlchemy 2.0 patterns.""" import logging from typing import Any @@ -9,10 +9,11 @@ from sqlalchemy import and_, case, func, or_, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from app.crud.base import CRUDBase +from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError from app.models.organization import Organization from app.models.user import User from app.models.user_organization import OrganizationRole, UserOrganization +from app.repositories.base import BaseRepository from app.schemas.organizations import ( OrganizationCreate, OrganizationUpdate, @@ -21,8 +22,8 @@ from app.schemas.organizations import ( logger = logging.getLogger(__name__) -class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]): - """Async CRUD operations for Organization model.""" +class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]): + """Repository for Organization model.""" async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None: """Get organization by slug.""" @@ -54,13 +55,13 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp except IntegrityError as e: await db.rollback() error_msg = str(e.orig) if hasattr(e, "orig") else str(e) - if "slug" in error_msg.lower(): + if "slug" in error_msg.lower() or "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): logger.warning(f"Duplicate slug attempted: {obj_in.slug}") - raise ValueError( + raise DuplicateEntryError( f"Organization with slug '{obj_in.slug}' already exists" ) logger.error(f"Integrity error creating organization: {error_msg}") - raise ValueError(f"Database integrity error: {error_msg}") + raise IntegrityConstraintError(f"Database integrity error: {error_msg}") except Exception as e: await db.rollback() logger.error( @@ -79,16 +80,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp sort_by: str = "created_at", sort_order: str = "desc", ) -> tuple[list[Organization], int]: - """ - Get multiple organizations with filtering, searching, and sorting. - - Returns: - Tuple of (organizations list, total count) - """ + """Get multiple organizations with filtering, searching, and sorting.""" try: query = select(Organization) - # Apply filters if is_active is not None: query = query.where(Organization.is_active == is_active) @@ -100,19 +95,16 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp ) query = query.where(search_filter) - # Get total count before pagination count_query = select(func.count()).select_from(query.alias()) count_result = await db.execute(count_query) total = count_result.scalar_one() - # Apply sorting sort_column = getattr(Organization, sort_by, Organization.created_at) if sort_order == "desc": query = query.order_by(sort_column.desc()) else: query = query.order_by(sort_column.asc()) - # Apply pagination query = query.offset(skip).limit(limit) result = await db.execute(query) organizations = list(result.scalars().all()) @@ -149,16 +141,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp is_active: bool | None = None, search: str | None = None, ) -> tuple[list[dict[str, Any]], int]: - """ - Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY. - This eliminates the N+1 query problem. - - Returns: - Tuple of (list of dicts with org and member_count, total count) - """ + """Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.""" try: - # Build base query with LEFT JOIN and GROUP BY - # Use CASE statement to count only active members query = ( select( Organization, @@ -181,7 +165,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp .group_by(Organization.id) ) - # Apply filters if is_active is not None: query = query.where(Organization.is_active == is_active) @@ -193,7 +176,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp ) query = query.where(search_filter) - # Get total count count_query = select(func.count(Organization.id)) if is_active is not None: count_query = count_query.where(Organization.is_active == is_active) @@ -203,7 +185,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp count_result = await db.execute(count_query) total = count_result.scalar_one() - # Apply pagination and ordering query = ( query.order_by(Organization.created_at.desc()).offset(skip).limit(limit) ) @@ -211,7 +192,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp result = await db.execute(query) rows = result.all() - # Convert to list of dicts orgs_with_counts = [ {"organization": org, "member_count": member_count} for org, member_count in rows @@ -236,7 +216,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp ) -> UserOrganization: """Add a user to an organization with a specific role.""" try: - # Check if relationship already exists result = await db.execute( select(UserOrganization).where( and_( @@ -248,7 +227,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp existing = result.scalar_one_or_none() if existing: - # Reactivate if inactive, or raise error if already active if not existing.is_active: existing.is_active = True existing.role = role @@ -257,9 +235,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp await db.refresh(existing) return existing else: - raise ValueError("User is already a member of this organization") + raise DuplicateEntryError("User is already a member of this organization") - # Create new relationship user_org = UserOrganization( user_id=user_id, organization_id=organization_id, @@ -274,7 +251,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp except IntegrityError as e: await db.rollback() logger.error(f"Integrity error adding user to organization: {e!s}") - raise ValueError("Failed to add user to organization") + raise IntegrityConstraintError("Failed to add user to organization") except Exception as e: await db.rollback() logger.error(f"Error adding user to organization: {e!s}", exc_info=True) @@ -350,14 +327,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp limit: int = 100, is_active: bool = True, ) -> tuple[list[dict[str, Any]], int]: - """ - Get members of an organization with user details. - - Returns: - Tuple of (members list with user details, total count) - """ + """Get members of an organization with user details.""" try: - # Build query with join query = ( select(UserOrganization, User) .join(User, UserOrganization.user_id == User.id) @@ -367,7 +338,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp if is_active is not None: query = query.where(UserOrganization.is_active == is_active) - # Get total count count_query = select(func.count()).select_from( select(UserOrganization) .where(UserOrganization.organization_id == organization_id) @@ -381,7 +351,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp count_result = await db.execute(count_query) total = count_result.scalar_one() - # Apply ordering and pagination query = ( query.order_by(UserOrganization.created_at.desc()) .offset(skip) @@ -435,15 +404,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp async def get_user_organizations_with_details( self, db: AsyncSession, *, user_id: UUID, is_active: bool = True ) -> list[dict[str, Any]]: - """ - Get user's organizations with role and member count in SINGLE QUERY. - Eliminates N+1 problem by using subquery for member counts. - - Returns: - List of dicts with organization, role, and member_count - """ + """Get user's organizations with role and member count in SINGLE QUERY.""" try: - # Subquery to get member counts for each organization member_count_subq = ( select( UserOrganization.organization_id, @@ -454,7 +416,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp .subquery() ) - # Main query with JOIN to get org, role, and member count query = ( select( Organization, @@ -531,5 +492,5 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] -# Create a singleton instance for use across the application -organization = CRUDOrganization(Organization) +# Singleton instance +organization_repo = OrganizationRepository(Organization) diff --git a/backend/app/crud/session.py b/backend/app/repositories/session.py old mode 100755 new mode 100644 similarity index 67% rename from backend/app/crud/session.py rename to backend/app/repositories/session.py index efa139b..78a9953 --- a/backend/app/crud/session.py +++ b/backend/app/repositories/session.py @@ -1,6 +1,5 @@ -""" -Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. -""" +# app/repositories/session.py +"""Repository for UserSession model async CRUD operations using SQLAlchemy 2.0 patterns.""" import logging import uuid @@ -11,27 +10,19 @@ from sqlalchemy import and_, delete, func, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload -from app.crud.base import CRUDBase +from app.core.repository_exceptions import InvalidInputError, IntegrityConstraintError from app.models.user_session import UserSession +from app.repositories.base import BaseRepository from app.schemas.sessions import SessionCreate, SessionUpdate logger = logging.getLogger(__name__) -class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): - """Async CRUD operations for user sessions.""" +class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]): + """Repository for UserSession model.""" async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None: - """ - Get session by refresh token JTI. - - Args: - db: Database session - jti: Refresh token JWT ID - - Returns: - UserSession if found, None otherwise - """ + """Get session by refresh token JTI.""" try: result = await db.execute( select(UserSession).where(UserSession.refresh_token_jti == jti) @@ -44,16 +35,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): async def get_active_by_jti( self, db: AsyncSession, *, jti: str ) -> UserSession | None: - """ - Get active session by refresh token JTI. - - Args: - db: Database session - jti: Refresh token JWT ID - - Returns: - Active UserSession if found, None otherwise - """ + """Get active session by refresh token JTI.""" try: result = await db.execute( select(UserSession).where( @@ -76,25 +58,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): active_only: bool = True, with_user: bool = False, ) -> list[UserSession]: - """ - Get all sessions for a user with optional eager loading. - - Args: - db: Database session - user_id: User ID - active_only: If True, return only active sessions - with_user: If True, eager load user relationship to prevent N+1 - - Returns: - List of UserSession objects - """ + """Get all sessions for a user with optional eager loading.""" try: - # Convert user_id string to UUID if needed user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id query = select(UserSession).where(UserSession.user_id == user_uuid) - # Add eager loading if requested to prevent N+1 queries if with_user: query = query.options(joinedload(UserSession.user)) @@ -111,19 +80,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): async def create_session( self, db: AsyncSession, *, obj_in: SessionCreate ) -> UserSession: - """ - Create a new user session. - - Args: - db: Database session - obj_in: SessionCreate schema with session data - - Returns: - Created UserSession - - Raises: - ValueError: If session creation fails - """ + """Create a new user session.""" try: db_obj = UserSession( user_id=obj_in.user_id, @@ -151,21 +108,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): except Exception as e: await db.rollback() logger.error(f"Error creating session: {e!s}", exc_info=True) - raise ValueError(f"Failed to create session: {e!s}") + raise IntegrityConstraintError(f"Failed to create session: {e!s}") async def deactivate( self, db: AsyncSession, *, session_id: str ) -> UserSession | None: - """ - Deactivate a session (logout from device). - - Args: - db: Database session - session_id: Session UUID - - Returns: - Deactivated UserSession if found, None otherwise - """ + """Deactivate a session (logout from device).""" try: session = await self.get(db, id=session_id) if not session: @@ -191,18 +139,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): async def deactivate_all_user_sessions( self, db: AsyncSession, *, user_id: str ) -> int: - """ - Deactivate all active sessions for a user (logout from all devices). - - Args: - db: Database session - user_id: User ID - - Returns: - Number of sessions deactivated - """ + """Deactivate all active sessions for a user (logout from all devices).""" try: - # Convert user_id string to UUID if needed user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id stmt = ( @@ -227,16 +165,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): async def update_last_used( self, db: AsyncSession, *, session: UserSession ) -> UserSession: - """ - Update the last_used_at timestamp for a session. - - Args: - db: Database session - session: UserSession object - - Returns: - Updated UserSession - """ + """Update the last_used_at timestamp for a session.""" try: session.last_used_at = datetime.now(UTC) db.add(session) @@ -256,20 +185,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): new_jti: str, new_expires_at: datetime, ) -> UserSession: - """ - Update session with new refresh token JTI and expiration. - - Called during token refresh. - - Args: - db: Database session - session: UserSession object - new_jti: New refresh token JTI - new_expires_at: New expiration datetime - - Returns: - Updated UserSession - """ + """Update session with new refresh token JTI and expiration.""" try: session.refresh_token_jti = new_jti session.expires_at = new_expires_at @@ -286,27 +202,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): raise async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int: - """ - Clean up expired sessions using optimized bulk DELETE. - - Deletes sessions that are: - - Expired AND inactive - - Older than keep_days - - Uses single DELETE query instead of N individual deletes for efficiency. - - Args: - db: Database session - keep_days: Keep inactive sessions for this many days (for audit) - - Returns: - Number of sessions deleted - """ + """Clean up expired sessions using optimized bulk DELETE.""" try: cutoff_date = datetime.now(UTC) - timedelta(days=keep_days) now = datetime.now(UTC) - # Use bulk DELETE with WHERE clause - single query stmt = delete(UserSession).where( and_( UserSession.is_active == False, # noqa: E712 @@ -330,29 +230,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): raise async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int: - """ - Clean up expired and inactive sessions for a specific user. - - Uses single bulk DELETE query for efficiency instead of N individual deletes. - - Args: - db: Database session - user_id: User ID to cleanup sessions for - - Returns: - Number of sessions deleted - """ + """Clean up expired and inactive sessions for a specific user.""" try: - # Validate UUID try: uuid_obj = uuid.UUID(user_id) except (ValueError, AttributeError): logger.error(f"Invalid UUID format: {user_id}") - raise ValueError(f"Invalid user ID format: {user_id}") + raise InvalidInputError(f"Invalid user ID format: {user_id}") now = datetime.now(UTC) - # Use bulk DELETE with WHERE clause - single query stmt = delete(UserSession).where( and_( UserSession.user_id == uuid_obj, @@ -380,18 +267,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): raise async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int: - """ - Get count of active sessions for a user. - - Args: - db: Database session - user_id: User ID - - Returns: - Number of active sessions - """ + """Get count of active sessions for a user.""" try: - # Convert user_id string to UUID if needed user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id result = await db.execute( @@ -413,31 +290,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): active_only: bool = True, with_user: bool = True, ) -> tuple[list[UserSession], int]: - """ - Get all sessions across all users with pagination (admin only). - - Args: - db: Database session - skip: Number of records to skip - limit: Maximum number of records to return - active_only: If True, return only active sessions - with_user: If True, eager load user relationship to prevent N+1 - - Returns: - Tuple of (list of UserSession objects, total count) - """ + """Get all sessions across all users with pagination (admin only).""" try: - # Build query query = select(UserSession) - # Add eager loading if requested to prevent N+1 queries if with_user: query = query.options(joinedload(UserSession.user)) if active_only: query = query.where(UserSession.is_active) - # Get total count count_query = select(func.count(UserSession.id)) if active_only: count_query = count_query.where(UserSession.is_active) @@ -445,7 +307,6 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): count_result = await db.execute(count_query) total = count_result.scalar_one() - # Apply pagination and ordering query = ( query.order_by(UserSession.last_used_at.desc()) .offset(skip) @@ -462,5 +323,5 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): raise -# Create singleton instance -session = CRUDSession(UserSession) +# Singleton instance +session_repo = SessionRepository(UserSession) diff --git a/backend/app/crud/user.py b/backend/app/repositories/user.py old mode 100755 new mode 100644 similarity index 71% rename from backend/app/crud/user.py rename to backend/app/repositories/user.py index d938303..97b4dcd --- a/backend/app/crud/user.py +++ b/backend/app/repositories/user.py @@ -1,5 +1,5 @@ -# app/crud/user_async.py -"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns.""" +# app/repositories/user.py +"""Repository for User model async CRUD operations using SQLAlchemy 2.0 patterns.""" import logging from datetime import UTC, datetime @@ -11,15 +11,16 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.core.auth import get_password_hash_async -from app.crud.base import CRUDBase +from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError from app.models.user import User +from app.repositories.base import BaseRepository from app.schemas.users import UserCreate, UserUpdate logger = logging.getLogger(__name__) -class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): - """Async CRUD operations for User model.""" +class UserRepository(BaseRepository[User, UserCreate, UserUpdate]): + """Repository for User model.""" async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None: """Get user by email address.""" @@ -33,7 +34,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User: """Create a new user with async password hashing and error handling.""" try: - # Hash password asynchronously to avoid blocking event loop password_hash = await get_password_hash_async(obj_in.password) db_obj = User( @@ -58,14 +58,48 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): error_msg = str(e.orig) if hasattr(e, "orig") else str(e) if "email" in error_msg.lower(): logger.warning(f"Duplicate email attempted: {obj_in.email}") - raise ValueError(f"User with email {obj_in.email} already exists") + raise DuplicateEntryError(f"User with email {obj_in.email} already exists") logger.error(f"Integrity error creating user: {error_msg}") - raise ValueError(f"Database integrity error: {error_msg}") + raise DuplicateEntryError(f"Database integrity error: {error_msg}") except Exception as e: await db.rollback() logger.error(f"Unexpected error creating user: {e!s}", exc_info=True) raise + async def create_oauth_user( + self, + db: AsyncSession, + *, + email: str, + first_name: str = "User", + last_name: str | None = None, + ) -> User: + """Create a new passwordless user for OAuth sign-in.""" + try: + db_obj = User( + email=email, + password_hash=None, # OAuth-only user + first_name=first_name, + last_name=last_name, + is_active=True, + is_superuser=False, + ) + db.add(db_obj) + await db.flush() # Get user.id without committing + return db_obj + except IntegrityError as e: + await db.rollback() + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) + if "email" in error_msg.lower(): + logger.warning(f"Duplicate email attempted: {email}") + raise DuplicateEntryError(f"User with email {email} already exists") + logger.error(f"Integrity error creating OAuth user: {error_msg}") + raise DuplicateEntryError(f"Database integrity error: {error_msg}") + except Exception as e: + await db.rollback() + logger.error(f"Unexpected error creating OAuth user: {e!s}", exc_info=True) + raise + async def update( self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any] ) -> User: @@ -75,8 +109,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): else: update_data = obj_in.model_dump(exclude_unset=True) - # Handle password separately if it exists in update data - # Hash password asynchronously to avoid blocking event loop if "password" in update_data: update_data["password_hash"] = await get_password_hash_async( update_data["password"] @@ -85,6 +117,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): return await super().update(db, db_obj=db_obj, obj_in=update_data) + async def update_password( + self, db: AsyncSession, *, user: User, password_hash: str + ) -> User: + """Set a new password hash on a user and commit.""" + user.password_hash = password_hash + await db.commit() + await db.refresh(user) + return user + async def get_multi_with_total( self, db: AsyncSession, @@ -96,43 +137,23 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): filters: dict[str, Any] | None = None, search: str | None = None, ) -> tuple[list[User], int]: - """ - Get multiple users with total count, filtering, sorting, and search. - - Args: - db: Database session - skip: Number of records to skip - limit: Maximum number of records to return - sort_by: Field name to sort by - sort_order: Sort order ("asc" or "desc") - filters: Dictionary of filters (field_name: value) - search: Search term to match against email, first_name, last_name - - Returns: - Tuple of (users list, total count) - """ - # Validate pagination + """Get multiple users with total count, filtering, sorting, and search.""" if skip < 0: - raise ValueError("skip must be non-negative") + raise InvalidInputError("skip must be non-negative") if limit < 0: - raise ValueError("limit must be non-negative") + raise InvalidInputError("limit must be non-negative") if limit > 1000: - raise ValueError("Maximum limit is 1000") + raise InvalidInputError("Maximum limit is 1000") try: - # Build base query query = select(User) - - # Exclude soft-deleted users query = query.where(User.deleted_at.is_(None)) - # Apply filters if filters: for field, value in filters.items(): if hasattr(User, field) and value is not None: query = query.where(getattr(User, field) == value) - # Apply search if search: search_filter = or_( User.email.ilike(f"%{search}%"), @@ -141,14 +162,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): ) query = query.where(search_filter) - # Get total count from sqlalchemy import func count_query = select(func.count()).select_from(query.alias()) count_result = await db.execute(count_query) total = count_result.scalar_one() - # Apply sorting if sort_by and hasattr(User, sort_by): sort_column = getattr(User, sort_by) if sort_order.lower() == "desc": @@ -156,7 +175,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): else: query = query.order_by(sort_column.asc()) - # Apply pagination query = query.offset(skip).limit(limit) result = await db.execute(query) users = list(result.scalars().all()) @@ -170,26 +188,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): async def bulk_update_status( self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool ) -> int: - """ - Bulk update is_active status for multiple users. - - Args: - db: Database session - user_ids: List of user IDs to update - is_active: New active status - - Returns: - Number of users updated - """ + """Bulk update is_active status for multiple users.""" try: if not user_ids: return 0 - # Use UPDATE with WHERE IN for efficiency stmt = ( update(User) .where(User.id.in_(user_ids)) - .where(User.deleted_at.is_(None)) # Don't update deleted users + .where(User.deleted_at.is_(None)) .values(is_active=is_active, updated_at=datetime.now(UTC)) ) @@ -212,34 +219,20 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): user_ids: list[UUID], exclude_user_id: UUID | None = None, ) -> int: - """ - Bulk soft delete multiple users. - - Args: - db: Database session - user_ids: List of user IDs to delete - exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action) - - Returns: - Number of users deleted - """ + """Bulk soft delete multiple users.""" try: if not user_ids: return 0 - # Remove excluded user from list filtered_ids = [uid for uid in user_ids if uid != exclude_user_id] if not filtered_ids: return 0 - # Use UPDATE with WHERE IN for efficiency stmt = ( update(User) .where(User.id.in_(filtered_ids)) - .where( - User.deleted_at.is_(None) - ) # Don't re-delete already deleted users + .where(User.deleted_at.is_(None)) .values( deleted_at=datetime.now(UTC), is_active=False, @@ -268,5 +261,5 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): return user.is_superuser -# Create a singleton instance for use across the application -user = CRUDUser(User) +# Singleton instance +user_repo = UserRepository(User) diff --git a/backend/app/schemas/oauth.py b/backend/app/schemas/oauth.py index a547ed8..fab3c69 100644 --- a/backend/app/schemas/oauth.py +++ b/backend/app/schemas/oauth.py @@ -60,8 +60,8 @@ class OAuthAccountCreate(OAuthAccountBase): user_id: UUID provider_user_id: str = Field(..., max_length=255) - access_token_encrypted: str | None = None - refresh_token_encrypted: str | None = None + access_token: str | None = None + refresh_token: str | None = None token_expires_at: datetime | None = None diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index 1487153..f06b9a5 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -1,5 +1,19 @@ # app/services/__init__.py +from . import oauth_provider_service from .auth_service import AuthService from .oauth_service import OAuthService +from .organization_service import OrganizationService, organization_service +from .session_service import SessionService, session_service +from .user_service import UserService, user_service -__all__ = ["AuthService", "OAuthService"] +__all__ = [ + "AuthService", + "OAuthService", + "UserService", + "OrganizationService", + "SessionService", + "oauth_provider_service", + "user_service", + "organization_service", + "session_service", +] diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index bbfdbc7..8f4b686 100755 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -2,7 +2,6 @@ import logging from uuid import UUID -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.auth import ( @@ -14,12 +13,18 @@ from app.core.auth import ( verify_password_async, ) from app.core.config import settings -from app.core.exceptions import AuthenticationError +from app.core.exceptions import AuthenticationError, DuplicateError +from app.core.repository_exceptions import DuplicateEntryError from app.models.user import User +from app.repositories.user import user_repo from app.schemas.users import Token, UserCreate, UserResponse logger = logging.getLogger(__name__) +# Pre-computed bcrypt hash used for constant-time comparison when user is not found, +# preventing timing attacks that could enumerate valid email addresses. +_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia" + class AuthService: """Service for handling authentication operations""" @@ -39,10 +44,12 @@ class AuthService: Returns: User if authenticated, None otherwise """ - result = await db.execute(select(User).where(User.email == email)) - user = result.scalar_one_or_none() + user = await user_repo.get_by_email(db, email=email) if not user: + # Perform a dummy verification to match timing of a real bcrypt check, + # preventing email enumeration via response-time differences. + await verify_password_async(password, _DUMMY_HASH) return None # Verify password asynchronously to avoid blocking event loop @@ -71,39 +78,22 @@ class AuthService: """ try: # Check if user already exists - result = await db.execute(select(User).where(User.email == user_data.email)) - existing_user = result.scalar_one_or_none() + existing_user = await user_repo.get_by_email(db, email=user_data.email) if existing_user: - raise AuthenticationError("User with this email already exists") + raise DuplicateError("User with this email already exists") - # Create new user with async password hashing - # Hash password asynchronously to avoid blocking event loop - hashed_password = await get_password_hash_async(user_data.password) - - # Create user object from model - user = User( - email=user_data.email, - password_hash=hashed_password, - first_name=user_data.first_name, - last_name=user_data.last_name, - phone_number=user_data.phone_number, - is_active=True, - is_superuser=False, - ) - - db.add(user) - await db.commit() - await db.refresh(user) + # Delegate creation (hashing + commit) to the repository + user = await user_repo.create(db, obj_in=user_data) logger.info(f"User created successfully: {user.email}") return user - except AuthenticationError: - # Re-raise authentication errors without rollback + except (AuthenticationError, DuplicateError): + # Re-raise API exceptions without rollback raise + except DuplicateEntryError as e: + raise DuplicateError(str(e)) except Exception as e: - # Rollback on any database errors - await db.rollback() logger.error(f"Error creating user: {e!s}", exc_info=True) raise AuthenticationError(f"Failed to create user: {e!s}") @@ -168,8 +158,7 @@ class AuthService: user_id = token_data.user_id # Get user from database - result = await db.execute(select(User).where(User.id == user_id)) - user = result.scalar_one_or_none() + user = await user_repo.get(db, id=str(user_id)) if not user or not user.is_active: raise TokenInvalidError("Invalid user or inactive account") @@ -200,8 +189,7 @@ class AuthService: AuthenticationError: If current password is incorrect or update fails """ try: - result = await db.execute(select(User).where(User.id == user_id)) - user = result.scalar_one_or_none() + user = await user_repo.get(db, id=str(user_id)) if not user: raise AuthenticationError("User not found") @@ -210,8 +198,8 @@ class AuthService: raise AuthenticationError("Current password is incorrect") # Hash new password asynchronously to avoid blocking event loop - user.password_hash = await get_password_hash_async(new_password) - await db.commit() + new_hash = await get_password_hash_async(new_password) + await user_repo.update_password(db, user=user, password_hash=new_hash) logger.info(f"Password changed successfully for user {user_id}") return True @@ -226,3 +214,32 @@ class AuthService: f"Error changing password for user {user_id}: {e!s}", exc_info=True ) raise AuthenticationError(f"Failed to change password: {e!s}") + + @staticmethod + async def reset_password( + db: AsyncSession, *, email: str, new_password: str + ) -> User: + """ + Reset a user's password without requiring the current password. + + Args: + db: Database session + email: User email address + new_password: New password to set + + Returns: + Updated user + + Raises: + AuthenticationError: If user not found or inactive + """ + user = await user_repo.get_by_email(db, email=email) + if not user: + raise AuthenticationError("User not found") + if not user.is_active: + raise AuthenticationError("User account is inactive") + + new_hash = await get_password_hash_async(new_password) + user = await user_repo.update_password(db, user=user, password_hash=new_hash) + logger.info(f"Password reset successfully for {email}") + return user diff --git a/backend/app/services/oauth_provider_service.py b/backend/app/services/oauth_provider_service.py index 530c9e8..0855fcb 100755 --- a/backend/app/services/oauth_provider_service.py +++ b/backend/app/services/oauth_provider_service.py @@ -26,14 +26,17 @@ from typing import Any from uuid import UUID from jose import jwt -from sqlalchemy import and_, delete, select from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings -from app.models.oauth_authorization_code import OAuthAuthorizationCode from app.models.oauth_client import OAuthClient -from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken +from app.schemas.oauth import OAuthClientCreate from app.models.user import User +from app.repositories.oauth_authorization_code import oauth_authorization_code_repo +from app.repositories.oauth_client import oauth_client_repo +from app.repositories.oauth_consent import oauth_consent_repo +from app.repositories.oauth_provider_token import oauth_provider_token_repo +from app.repositories.user import user_repo logger = logging.getLogger(__name__) @@ -161,15 +164,7 @@ def join_scope(scopes: list[str]) -> str: async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None: """Get OAuth client by client_id.""" - result = await db.execute( - select(OAuthClient).where( - and_( - OAuthClient.client_id == client_id, - OAuthClient.is_active == True, # noqa: E712 - ) - ) - ) - return result.scalar_one_or_none() + return await oauth_client_repo.get_by_client_id(db, client_id=client_id) async def validate_client( @@ -204,21 +199,19 @@ async def validate_client( if not client.client_secret_hash: raise InvalidClientError("Client not configured with secret") - # SECURITY: Verify secret using bcrypt (not SHA-256) - # Supports both bcrypt and legacy SHA-256 hashes for migration + # SECURITY: Verify secret using bcrypt from app.core.auth import verify_password stored_hash = str(client.client_secret_hash) - if stored_hash.startswith("$2"): - # New bcrypt format - if not verify_password(client_secret, stored_hash): - raise InvalidClientError("Invalid client secret") - else: - # Legacy SHA-256 format - computed_hash = hashlib.sha256(client_secret.encode()).hexdigest() - if not secrets.compare_digest(computed_hash, stored_hash): - raise InvalidClientError("Invalid client secret") + if not stored_hash.startswith("$2"): + raise InvalidClientError( + "Client secret uses deprecated hash format. " + "Please regenerate your client credentials." + ) + + if not verify_password(client_secret, stored_hash): + raise InvalidClientError("Invalid client secret") return client @@ -311,23 +304,20 @@ async def create_authorization_code( minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES ) - auth_code = OAuthAuthorizationCode( + await oauth_authorization_code_repo.create_code( + db, code=code, client_id=client.client_id, user_id=user.id, redirect_uri=redirect_uri, scope=scope, + expires_at=expires_at, code_challenge=code_challenge, code_challenge_method=code_challenge_method, state=state, nonce=nonce, - expires_at=expires_at, - used=False, ) - db.add(auth_code) - await db.commit() - logger.info( f"Created authorization code for user {user.id} and client {client.client_id}" ) @@ -366,30 +356,14 @@ async def exchange_authorization_code( """ # Atomically mark code as used and fetch it (prevents race condition) # RFC 6749 Section 4.1.2: Authorization codes MUST be single-use - from sqlalchemy import update - - # First, atomically mark the code as used and get affected count - update_stmt = ( - update(OAuthAuthorizationCode) - .where( - and_( - OAuthAuthorizationCode.code == code, - OAuthAuthorizationCode.used == False, # noqa: E712 - ) - ) - .values(used=True) - .returning(OAuthAuthorizationCode.id) + updated_id = await oauth_authorization_code_repo.consume_code_atomically( + db, code=code ) - result = await db.execute(update_stmt) - updated_id = result.scalar_one_or_none() if not updated_id: # Either code doesn't exist or was already used # Check if it exists to provide appropriate error - check_result = await db.execute( - select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code) - ) - existing_code = check_result.scalar_one_or_none() + existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code) if existing_code and existing_code.used: # Code reuse is a security incident - revoke all tokens for this grant @@ -404,11 +378,9 @@ async def exchange_authorization_code( raise InvalidGrantError("Invalid authorization code") # Now fetch the full auth code record - auth_code_result = await db.execute( - select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id) - ) - auth_code = auth_code_result.scalar_one() - await db.commit() + auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id) + if auth_code is None: + raise InvalidGrantError("Authorization code not found after consumption") if auth_code.is_expired: raise InvalidGrantError("Authorization code has expired") @@ -452,8 +424,7 @@ async def exchange_authorization_code( raise InvalidGrantError("PKCE required for public clients") # Get user - user_result = await db.execute(select(User).where(User.id == auth_code.user_id)) - user = user_result.scalar_one_or_none() + user = await user_repo.get(db, id=str(auth_code.user_id)) if not user or not user.is_active: raise InvalidGrantError("User not found or inactive") @@ -543,7 +514,8 @@ async def create_tokens( refresh_token_hash = hash_token(refresh_token) # Store refresh token in database - refresh_token_record = OAuthProviderRefreshToken( + await oauth_provider_token_repo.create_token( + db, token_hash=refresh_token_hash, jti=jti, client_id=client.client_id, @@ -553,8 +525,6 @@ async def create_tokens( device_info=device_info, ip_address=ip_address, ) - db.add(refresh_token_record) - await db.commit() logger.info(f"Issued tokens for user {user.id} to client {client.client_id}") @@ -599,12 +569,9 @@ async def refresh_tokens( """ # Find refresh token token_hash = hash_token(refresh_token) - result = await db.execute( - select(OAuthProviderRefreshToken).where( - OAuthProviderRefreshToken.token_hash == token_hash - ) + token_record = await oauth_provider_token_repo.get_by_token_hash( + db, token_hash=token_hash ) - token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none() if not token_record: raise InvalidGrantError("Invalid refresh token") @@ -631,8 +598,7 @@ async def refresh_tokens( ) # Get user - user_result = await db.execute(select(User).where(User.id == token_record.user_id)) - user = user_result.scalar_one_or_none() + user = await user_repo.get(db, id=str(token_record.user_id)) if not user or not user.is_active: raise InvalidGrantError("User not found or inactive") @@ -648,9 +614,7 @@ async def refresh_tokens( final_scope = token_scope # Revoke old refresh token (token rotation) - token_record.revoked = True # type: ignore[assignment] - token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment] - await db.commit() + await oauth_provider_token_repo.revoke(db, token=token_record) # Issue new tokens device = str(token_record.device_info) if token_record.device_info else None @@ -697,20 +661,16 @@ async def revoke_token( # Try as refresh token first (more likely) if token_type_hint != "access_token": token_hash = hash_token(token) - result = await db.execute( - select(OAuthProviderRefreshToken).where( - OAuthProviderRefreshToken.token_hash == token_hash - ) + refresh_record = await oauth_provider_token_repo.get_by_token_hash( + db, token_hash=token_hash ) - refresh_record = result.scalar_one_or_none() if refresh_record: # Validate client if provided if client_id and refresh_record.client_id != client_id: raise InvalidClientError("Token was not issued to this client") - refresh_record.revoked = True # type: ignore[assignment] - await db.commit() + await oauth_provider_token_repo.revoke(db, token=refresh_record) logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...") return True @@ -731,17 +691,13 @@ async def revoke_token( jti = payload.get("jti") if jti: # Find and revoke the associated refresh token - result = await db.execute( - select(OAuthProviderRefreshToken).where( - OAuthProviderRefreshToken.jti == jti - ) + refresh_record = await oauth_provider_token_repo.get_by_jti( + db, jti=jti ) - refresh_record = result.scalar_one_or_none() if refresh_record: if client_id and refresh_record.client_id != client_id: raise InvalidClientError("Token was not issued to this client") - refresh_record.revoked = True # type: ignore[assignment] - await db.commit() + await oauth_provider_token_repo.revoke(db, token=refresh_record) logger.info( f"Revoked refresh token via access token JTI {jti[:8]}..." ) @@ -770,24 +726,11 @@ async def revoke_tokens_for_user_client( Returns: Number of tokens revoked """ - result = await db.execute( - select(OAuthProviderRefreshToken).where( - and_( - OAuthProviderRefreshToken.user_id == user_id, - OAuthProviderRefreshToken.client_id == client_id, - OAuthProviderRefreshToken.revoked == False, # noqa: E712 - ) - ) + count = await oauth_provider_token_repo.revoke_all_for_user_client( + db, user_id=user_id, client_id=client_id ) - tokens = result.scalars().all() - - count = 0 - for token in tokens: - token.revoked = True # type: ignore[assignment] - count += 1 if count > 0: - await db.commit() logger.warning( f"Revoked {count} tokens for user {user_id} and client {client_id}" ) @@ -808,23 +751,9 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int: Returns: Number of tokens revoked """ - result = await db.execute( - select(OAuthProviderRefreshToken).where( - and_( - OAuthProviderRefreshToken.user_id == user_id, - OAuthProviderRefreshToken.revoked == False, # noqa: E712 - ) - ) - ) - tokens = result.scalars().all() - - count = 0 - for token in tokens: - token.revoked = True # type: ignore[assignment] - count += 1 + count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id) if count > 0: - await db.commit() logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}") return count @@ -878,12 +807,9 @@ async def introspect_token( # Check if associated refresh token is revoked jti = payload.get("jti") if jti: - result = await db.execute( - select(OAuthProviderRefreshToken).where( - OAuthProviderRefreshToken.jti == jti - ) + refresh_record = await oauth_provider_token_repo.get_by_jti( + db, jti=jti ) - refresh_record = result.scalar_one_or_none() if refresh_record and refresh_record.revoked: return {"active": False} @@ -907,12 +833,9 @@ async def introspect_token( # Try as refresh token if token_type_hint != "access_token": token_hash = hash_token(token) - result = await db.execute( - select(OAuthProviderRefreshToken).where( - OAuthProviderRefreshToken.token_hash == token_hash - ) + refresh_record = await oauth_provider_token_repo.get_by_token_hash( + db, token_hash=token_hash ) - refresh_record = result.scalar_one_or_none() if refresh_record and refresh_record.is_valid: return { @@ -937,17 +860,9 @@ async def get_consent( db: AsyncSession, user_id: UUID, client_id: str, -) -> OAuthConsent | None: +): """Get existing consent record for user-client pair.""" - result = await db.execute( - select(OAuthConsent).where( - and_( - OAuthConsent.user_id == user_id, - OAuthConsent.client_id == client_id, - ) - ) - ) - return result.scalar_one_or_none() + return await oauth_consent_repo.get_consent(db, user_id=user_id, client_id=client_id) async def check_consent( @@ -972,31 +887,15 @@ async def grant_consent( user_id: UUID, client_id: str, scopes: list[str], -) -> OAuthConsent: +): """ Grant or update consent for a user-client pair. If consent already exists, updates the granted scopes. """ - consent = await get_consent(db, user_id, client_id) - - if consent: - # Merge scopes - granted = str(consent.granted_scopes) if consent.granted_scopes else "" - existing = set(parse_scope(granted)) - new_scopes = existing | set(scopes) - consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment] - else: - consent = OAuthConsent( - user_id=user_id, - client_id=client_id, - granted_scopes=join_scope(scopes), - ) - db.add(consent) - - await db.commit() - await db.refresh(consent) - return consent + return await oauth_consent_repo.grant_consent( + db, user_id=user_id, client_id=client_id, scopes=scopes + ) async def revoke_consent( @@ -1009,21 +908,13 @@ async def revoke_consent( Returns True if consent was found and revoked. """ - # Delete consent record - result = await db.execute( - delete(OAuthConsent).where( - and_( - OAuthConsent.user_id == user_id, - OAuthConsent.client_id == client_id, - ) - ) - ) - - # Revoke all tokens + # Revoke all tokens first await revoke_tokens_for_user_client(db, user_id, client_id) - await db.commit() - return result.rowcount > 0 # type: ignore[attr-defined] + # Delete consent record + return await oauth_consent_repo.revoke_consent( + db, user_id=user_id, client_id=client_id + ) # ============================================================================ @@ -1031,6 +922,26 @@ async def revoke_consent( # ============================================================================ +async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple: + """Create a new OAuth client. Returns (client, secret).""" + return await oauth_client_repo.create_client(db, obj_in=client_data) + + +async def list_clients(db: AsyncSession) -> list: + """List all registered OAuth clients.""" + return await oauth_client_repo.get_all_clients(db) + + +async def delete_client_by_id(db: AsyncSession, client_id: str) -> None: + """Delete an OAuth client by client_id.""" + await oauth_client_repo.delete_client(db, client_id=client_id) + + +async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]: + """Get all OAuth consents for a user with client details.""" + return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id) + + async def cleanup_expired_codes(db: AsyncSession) -> int: """ Delete expired authorization codes. @@ -1040,13 +951,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int: Returns: Number of codes deleted """ - result = await db.execute( - delete(OAuthAuthorizationCode).where( - OAuthAuthorizationCode.expires_at < datetime.now(UTC) - ) - ) - await db.commit() - return result.rowcount # type: ignore[attr-defined] + return await oauth_authorization_code_repo.cleanup_expired(db) async def cleanup_expired_tokens(db: AsyncSession) -> int: @@ -1058,12 +963,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int: Returns: Number of tokens deleted """ - # Delete tokens that are both expired AND revoked (or just very old) - cutoff = datetime.now(UTC) - timedelta(days=7) - result = await db.execute( - delete(OAuthProviderRefreshToken).where( - OAuthProviderRefreshToken.expires_at < cutoff - ) - ) - await db.commit() - return result.rowcount # type: ignore[attr-defined] + return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7) diff --git a/backend/app/services/oauth_service.py b/backend/app/services/oauth_service.py index 561a4bf..20c4e73 100644 --- a/backend/app/services/oauth_service.py +++ b/backend/app/services/oauth_service.py @@ -19,14 +19,15 @@ from typing import TypedDict, cast from uuid import UUID from authlib.integrations.httpx_client import AsyncOAuth2Client -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.auth import create_access_token, create_refresh_token from app.core.config import settings from app.core.exceptions import AuthenticationError -from app.crud import oauth_account, oauth_state +from app.repositories.oauth_account import oauth_account_repo as oauth_account +from app.repositories.oauth_state import oauth_state_repo as oauth_state from app.models.user import User +from app.repositories.user import user_repo from app.schemas.oauth import ( OAuthAccountCreate, OAuthCallbackResponse, @@ -343,7 +344,7 @@ class OAuthService: await oauth_account.update_tokens( db, account=existing_oauth, - access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC) + access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC) + timedelta(seconds=token.get("expires_in", 3600)), ) @@ -351,10 +352,7 @@ class OAuthService: elif state_record.user_id: # Account linking flow (user is already logged in) - result = await db.execute( - select(User).where(User.id == state_record.user_id) - ) - user = result.scalar_one_or_none() + user = await user_repo.get(db, id=str(state_record.user_id)) if not user: raise AuthenticationError("User not found for account linking") @@ -375,7 +373,7 @@ class OAuthService: provider=provider, provider_user_id=provider_user_id, provider_email=provider_email, - access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC) + access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC) + timedelta(seconds=token.get("expires_in", 3600)) if token.get("expires_in") else None, @@ -389,10 +387,7 @@ class OAuthService: user = None if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL: - result = await db.execute( - select(User).where(User.email == provider_email) - ) - user = result.scalar_one_or_none() + user = await user_repo.get_by_email(db, email=provider_email) if user: # Auto-link to existing user @@ -416,8 +411,8 @@ class OAuthService: provider=provider, provider_user_id=provider_user_id, provider_email=provider_email, - access_token_encrypted=token.get("access_token"), - refresh_token_encrypted=token.get("refresh_token"), + access_token=token.get("access_token"), + refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC) + timedelta(seconds=token.get("expires_in", 3600)) if token.get("expires_in") @@ -644,14 +639,13 @@ class OAuthService: provider=provider, provider_user_id=provider_user_id, provider_email=email, - access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC) + access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC) + timedelta(seconds=token.get("expires_in", 3600)) if token.get("expires_in") else None, ) await oauth_account.create_account(db, obj_in=oauth_create) - await db.commit() await db.refresh(user) return user @@ -701,6 +695,20 @@ class OAuthService: logger.info(f"OAuth provider unlinked: {provider} from {user.email}") return True + @staticmethod + async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list: + """Get all OAuth accounts linked to a user.""" + return await oauth_account.get_user_accounts(db, user_id=user_id) + + @staticmethod + async def get_user_account_by_provider( + db: AsyncSession, *, user_id: UUID, provider: str + ): + """Get a specific OAuth account for a user and provider.""" + return await oauth_account.get_user_account_by_provider( + db, user_id=user_id, provider=provider + ) + @staticmethod async def cleanup_expired_states(db: AsyncSession) -> int: """ diff --git a/backend/app/services/organization_service.py b/backend/app/services/organization_service.py new file mode 100644 index 0000000..de02311 --- /dev/null +++ b/backend/app/services/organization_service.py @@ -0,0 +1,157 @@ +# app/services/organization_service.py +"""Service layer for organization operations — delegates to OrganizationRepository.""" + +import logging +from typing import Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.exceptions import NotFoundError +from app.models.organization import Organization +from app.models.user_organization import OrganizationRole, UserOrganization +from app.repositories.organization import OrganizationRepository, organization_repo +from app.schemas.organizations import OrganizationCreate, OrganizationUpdate + +logger = logging.getLogger(__name__) + + +class OrganizationService: + """Service for organization management operations.""" + + def __init__( + self, organization_repository: OrganizationRepository | None = None + ) -> None: + self._repo = organization_repository or organization_repo + + async def get_organization(self, db: AsyncSession, org_id: str) -> Organization: + """Get organization by ID, raising NotFoundError if not found.""" + org = await self._repo.get(db, id=org_id) + if not org: + raise NotFoundError(f"Organization {org_id} not found") + return org + + async def create_organization( + self, db: AsyncSession, *, obj_in: OrganizationCreate + ) -> Organization: + """Create a new organization.""" + return await self._repo.create(db, obj_in=obj_in) + + async def update_organization( + self, + db: AsyncSession, + *, + org: Organization, + obj_in: OrganizationUpdate | dict[str, Any], + ) -> Organization: + """Update an existing organization.""" + return await self._repo.update(db, db_obj=org, obj_in=obj_in) + + async def remove_organization(self, db: AsyncSession, org_id: str) -> None: + """Permanently delete an organization by ID.""" + await self._repo.remove(db, id=org_id) + + async def get_member_count( + self, db: AsyncSession, *, organization_id: UUID + ) -> int: + """Get number of active members in an organization.""" + return await self._repo.get_member_count(db, organization_id=organization_id) + + async def get_multi_with_member_counts( + self, + db: AsyncSession, + *, + skip: int = 0, + limit: int = 100, + is_active: bool | None = None, + search: str | None = None, + ) -> tuple[list[dict[str, Any]], int]: + """List organizations with member counts and pagination.""" + return await self._repo.get_multi_with_member_counts( + db, skip=skip, limit=limit, is_active=is_active, search=search + ) + + async def get_user_organizations_with_details( + self, + db: AsyncSession, + *, + user_id: UUID, + is_active: bool | None = None, + ) -> list[dict[str, Any]]: + """Get all organizations a user belongs to, with membership details.""" + return await self._repo.get_user_organizations_with_details( + db, user_id=user_id, is_active=is_active + ) + + async def get_organization_members( + self, + db: AsyncSession, + *, + organization_id: UUID, + skip: int = 0, + limit: int = 100, + is_active: bool | None = True, + ) -> tuple[list[dict[str, Any]], int]: + """Get members of an organization with pagination.""" + return await self._repo.get_organization_members( + db, + organization_id=organization_id, + skip=skip, + limit=limit, + is_active=is_active, + ) + + async def add_member( + self, + db: AsyncSession, + *, + organization_id: UUID, + user_id: UUID, + role: OrganizationRole = OrganizationRole.MEMBER, + ) -> UserOrganization: + """Add a user to an organization.""" + return await self._repo.add_user( + db, organization_id=organization_id, user_id=user_id, role=role + ) + + async def remove_member( + self, + db: AsyncSession, + *, + organization_id: UUID, + user_id: UUID, + ) -> bool: + """Remove a user from an organization. Returns True if found and removed.""" + return await self._repo.remove_user( + db, organization_id=organization_id, user_id=user_id + ) + + async def get_user_role_in_org( + self, db: AsyncSession, *, user_id: UUID, organization_id: UUID + ) -> OrganizationRole | None: + """Get the role of a user in an organization.""" + return await self._repo.get_user_role_in_org( + db, user_id=user_id, organization_id=organization_id + ) + + async def get_org_distribution( + self, db: AsyncSession, *, limit: int = 6 + ) -> list[dict[str, Any]]: + """Return top organizations by member count for admin dashboard.""" + from sqlalchemy import func, select + + result = await db.execute( + select( + Organization.name, + func.count(UserOrganization.user_id).label("count"), + ) + .join(UserOrganization, Organization.id == UserOrganization.organization_id) + .group_by(Organization.name) + .order_by(func.count(UserOrganization.user_id).desc()) + .limit(limit) + ) + return [{"name": row.name, "value": row.count} for row in result.all()] + + +# Default singleton +organization_service = OrganizationService() diff --git a/backend/app/services/session_cleanup.py b/backend/app/services/session_cleanup.py index e993530..888166e 100755 --- a/backend/app/services/session_cleanup.py +++ b/backend/app/services/session_cleanup.py @@ -8,7 +8,7 @@ import logging from datetime import UTC, datetime from app.core.database import SessionLocal -from app.crud.session import session as session_crud +from app.repositories.session import session_repo as session_crud logger = logging.getLogger(__name__) diff --git a/backend/app/services/session_service.py b/backend/app/services/session_service.py new file mode 100644 index 0000000..a73590d --- /dev/null +++ b/backend/app/services/session_service.py @@ -0,0 +1,97 @@ +# app/services/session_service.py +"""Service layer for session operations — delegates to SessionRepository.""" + +import logging +from datetime import datetime + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.user_session import UserSession +from app.repositories.session import SessionRepository, session_repo +from app.schemas.sessions import SessionCreate + +logger = logging.getLogger(__name__) + + +class SessionService: + """Service for user session management operations.""" + + def __init__(self, session_repository: SessionRepository | None = None) -> None: + self._repo = session_repository or session_repo + + async def create_session( + self, db: AsyncSession, *, obj_in: SessionCreate + ) -> UserSession: + """Create a new session record.""" + return await self._repo.create_session(db, obj_in=obj_in) + + async def get_session(self, db: AsyncSession, session_id: str) -> UserSession | None: + """Get session by ID.""" + return await self._repo.get(db, id=session_id) + + async def get_user_sessions( + self, db: AsyncSession, *, user_id: str, active_only: bool = True + ) -> list[UserSession]: + """Get all sessions for a user.""" + return await self._repo.get_user_sessions( + db, user_id=user_id, active_only=active_only + ) + + async def get_active_by_jti( + self, db: AsyncSession, *, jti: str + ) -> UserSession | None: + """Get active session by refresh token JTI.""" + return await self._repo.get_active_by_jti(db, jti=jti) + + async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None: + """Get session by refresh token JTI (active or inactive).""" + return await self._repo.get_by_jti(db, jti=jti) + + async def deactivate( + self, db: AsyncSession, *, session_id: str + ) -> UserSession | None: + """Deactivate a session (logout from device).""" + return await self._repo.deactivate(db, session_id=session_id) + + async def deactivate_all_user_sessions( + self, db: AsyncSession, *, user_id: str + ) -> int: + """Deactivate all sessions for a user. Returns count deactivated.""" + return await self._repo.deactivate_all_user_sessions(db, user_id=user_id) + + async def update_refresh_token( + self, + db: AsyncSession, + *, + session: UserSession, + new_jti: str, + new_expires_at: datetime, + ) -> UserSession: + """Update session with a rotated refresh token.""" + return await self._repo.update_refresh_token( + db, session=session, new_jti=new_jti, new_expires_at=new_expires_at + ) + + async def cleanup_expired_for_user( + self, db: AsyncSession, *, user_id: str + ) -> int: + """Remove expired sessions for a user. Returns count removed.""" + return await self._repo.cleanup_expired_for_user(db, user_id=user_id) + + async def get_all_sessions( + self, + db: AsyncSession, + *, + skip: int = 0, + limit: int = 100, + active_only: bool = True, + with_user: bool = True, + ) -> tuple[list[UserSession], int]: + """Get all sessions with pagination (admin only).""" + return await self._repo.get_all_sessions( + db, skip=skip, limit=limit, active_only=active_only, with_user=with_user + ) + + +# Default singleton +session_service = SessionService() diff --git a/backend/app/services/user_service.py b/backend/app/services/user_service.py new file mode 100644 index 0000000..0ad787f --- /dev/null +++ b/backend/app/services/user_service.py @@ -0,0 +1,120 @@ +# app/services/user_service.py +"""Service layer for user operations — delegates to UserRepository.""" + +import logging +from typing import Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.exceptions import NotFoundError +from app.models.user import User +from app.repositories.user import UserRepository, user_repo +from app.schemas.users import UserCreate, UserUpdate + +logger = logging.getLogger(__name__) + + +class UserService: + """Service for user management operations.""" + + def __init__(self, user_repository: UserRepository | None = None) -> None: + self._repo = user_repository or user_repo + + async def get_user(self, db: AsyncSession, user_id: str) -> User: + """Get user by ID, raising NotFoundError if not found.""" + user = await self._repo.get(db, id=user_id) + if not user: + raise NotFoundError(f"User {user_id} not found") + return user + + async def get_by_email(self, db: AsyncSession, email: str) -> User | None: + """Get user by email address.""" + return await self._repo.get_by_email(db, email=email) + + async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User: + """Create a new user.""" + return await self._repo.create(db, obj_in=user_data) + + async def update_user( + self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any] + ) -> User: + """Update an existing user.""" + return await self._repo.update(db, db_obj=user, obj_in=obj_in) + + async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None: + """Soft-delete a user by ID.""" + await self._repo.soft_delete(db, id=user_id) + + async def list_users( + self, + db: AsyncSession, + *, + skip: int = 0, + limit: int = 100, + sort_by: str | None = None, + sort_order: str = "asc", + filters: dict[str, Any] | None = None, + search: str | None = None, + ) -> tuple[list[User], int]: + """List users with pagination, sorting, filtering, and search.""" + return await self._repo.get_multi_with_total( + db, + skip=skip, + limit=limit, + sort_by=sort_by, + sort_order=sort_order, + filters=filters, + search=search, + ) + + async def bulk_update_status( + self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool + ) -> int: + """Bulk update active status for multiple users. Returns count updated.""" + return await self._repo.bulk_update_status( + db, user_ids=user_ids, is_active=is_active + ) + + async def bulk_soft_delete( + self, + db: AsyncSession, + *, + user_ids: list[UUID], + exclude_user_id: UUID | None = None, + ) -> int: + """Bulk soft-delete multiple users. Returns count deleted.""" + return await self._repo.bulk_soft_delete( + db, user_ids=user_ids, exclude_user_id=exclude_user_id + ) + + async def get_stats(self, db: AsyncSession) -> dict[str, Any]: + """Return user stats needed for the admin dashboard.""" + from sqlalchemy import func, select + + total_users = ( + await db.execute(select(func.count()).select_from(User)) + ).scalar() or 0 + active_count = ( + await db.execute(select(func.count()).select_from(User).where(User.is_active)) + ).scalar() or 0 + inactive_count = ( + await db.execute( + select(func.count()).select_from(User).where(User.is_active.is_(False)) + ) + ).scalar() or 0 + all_users = list( + ( + await db.execute(select(User).order_by(User.created_at)) + ).scalars().all() + ) + return { + "total_users": total_users, + "active_count": active_count, + "inactive_count": inactive_count, + "all_users": all_users, + } + + +# Default singleton +user_service = UserService() diff --git a/backend/tests/api/test_admin.py b/backend/tests/api/test_admin.py index 3fc7f33..1392d0a 100644 --- a/backend/tests/api/test_admin.py +++ b/backend/tests/api/test_admin.py @@ -147,7 +147,7 @@ class TestAdminCreateUser: headers={"Authorization": f"Bearer {superuser_token}"}, ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_409_CONFLICT class TestAdminGetUser: @@ -565,7 +565,7 @@ class TestAdminCreateOrganization: headers={"Authorization": f"Bearer {superuser_token}"}, ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_409_CONFLICT class TestAdminGetOrganization: diff --git a/backend/tests/api/test_admin_error_handlers.py b/backend/tests/api/test_admin_error_handlers.py index f639e12..cc3d68b 100644 --- a/backend/tests/api/test_admin_error_handlers.py +++ b/backend/tests/api/test_admin_error_handlers.py @@ -45,7 +45,7 @@ class TestAdminListUsersFilters: async def test_list_users_database_error_propagates(self, client, superuser_token): """Test that database errors propagate correctly (covers line 118-120).""" with patch( - "app.api.routes.admin.user_crud.get_multi_with_total", + "app.api.routes.admin.user_service.list_users", side_effect=Exception("DB error"), ): with pytest.raises(Exception): @@ -74,8 +74,8 @@ class TestAdminCreateUserErrors: }, ) - # Should get error for duplicate email - assert response.status_code == status.HTTP_404_NOT_FOUND + # Should get conflict for duplicate email + assert response.status_code == status.HTTP_409_CONFLICT @pytest.mark.asyncio async def test_create_user_unexpected_error_propagates( @@ -83,7 +83,7 @@ class TestAdminCreateUserErrors: ): """Test unexpected errors during user creation (covers line 151-153).""" with patch( - "app.api.routes.admin.user_crud.create", + "app.api.routes.admin.user_service.create_user", side_effect=RuntimeError("Unexpected error"), ): with pytest.raises(RuntimeError): @@ -135,7 +135,7 @@ class TestAdminUpdateUserErrors: ): """Test unexpected errors during user update (covers line 206-208).""" with patch( - "app.api.routes.admin.user_crud.update", + "app.api.routes.admin.user_service.update_user", side_effect=RuntimeError("Update failed"), ): with pytest.raises(RuntimeError): @@ -166,7 +166,7 @@ class TestAdminDeleteUserErrors: ): """Test unexpected errors during user deletion (covers line 238-240).""" with patch( - "app.api.routes.admin.user_crud.soft_delete", + "app.api.routes.admin.user_service.soft_delete_user", side_effect=Exception("Delete failed"), ): with pytest.raises(Exception): @@ -196,7 +196,7 @@ class TestAdminActivateUserErrors: ): """Test unexpected errors during user activation (covers line 282-284).""" with patch( - "app.api.routes.admin.user_crud.update", + "app.api.routes.admin.user_service.update_user", side_effect=Exception("Activation failed"), ): with pytest.raises(Exception): @@ -238,7 +238,7 @@ class TestAdminDeactivateUserErrors: ): """Test unexpected errors during user deactivation (covers line 326-328).""" with patch( - "app.api.routes.admin.user_crud.update", + "app.api.routes.admin.user_service.update_user", side_effect=Exception("Deactivation failed"), ): with pytest.raises(Exception): @@ -258,7 +258,7 @@ class TestAdminListOrganizationsErrors: async def test_list_organizations_database_error(self, client, superuser_token): """Test list organizations with database error (covers line 427-456).""" with patch( - "app.api.routes.admin.organization_crud.get_multi_with_member_counts", + "app.api.routes.admin.organization_service.get_multi_with_member_counts", side_effect=Exception("DB error"), ): with pytest.raises(Exception): @@ -299,14 +299,14 @@ class TestAdminCreateOrganizationErrors: }, ) - # Should get error for duplicate slug - assert response.status_code == status.HTTP_404_NOT_FOUND + # Should get conflict for duplicate slug + assert response.status_code == status.HTTP_409_CONFLICT @pytest.mark.asyncio async def test_create_organization_unexpected_error(self, client, superuser_token): """Test unexpected errors during organization creation (covers line 484-485).""" with patch( - "app.api.routes.admin.organization_crud.create", + "app.api.routes.admin.organization_service.create_organization", side_effect=RuntimeError("Creation failed"), ): with pytest.raises(RuntimeError): @@ -367,7 +367,7 @@ class TestAdminUpdateOrganizationErrors: org_id = org.id with patch( - "app.api.routes.admin.organization_crud.update", + "app.api.routes.admin.organization_service.update_organization", side_effect=Exception("Update failed"), ): with pytest.raises(Exception): @@ -412,7 +412,7 @@ class TestAdminDeleteOrganizationErrors: org_id = org.id with patch( - "app.api.routes.admin.organization_crud.remove", + "app.api.routes.admin.organization_service.remove_organization", side_effect=Exception("Delete failed"), ): with pytest.raises(Exception): @@ -456,7 +456,7 @@ class TestAdminListOrganizationMembersErrors: org_id = org.id with patch( - "app.api.routes.admin.organization_crud.get_organization_members", + "app.api.routes.admin.organization_service.get_organization_members", side_effect=Exception("DB error"), ): with pytest.raises(Exception): @@ -531,7 +531,7 @@ class TestAdminAddOrganizationMemberErrors: org_id = org.id with patch( - "app.api.routes.admin.organization_crud.add_user", + "app.api.routes.admin.organization_service.add_member", side_effect=Exception("Add failed"), ): with pytest.raises(Exception): @@ -587,7 +587,7 @@ class TestAdminRemoveOrganizationMemberErrors: org_id = org.id with patch( - "app.api.routes.admin.organization_crud.remove_user", + "app.api.routes.admin.organization_service.remove_member", side_effect=Exception("Remove failed"), ): with pytest.raises(Exception): diff --git a/backend/tests/api/test_auth_error_handlers.py b/backend/tests/api/test_auth_error_handlers.py index ac95d37..451813a 100644 --- a/backend/tests/api/test_auth_error_handlers.py +++ b/backend/tests/api/test_auth_error_handlers.py @@ -19,7 +19,7 @@ class TestLoginSessionCreationFailure: """Test that login succeeds even if session creation fails.""" # Mock session creation to fail with patch( - "app.api.routes.auth.session_crud.create_session", + "app.api.routes.auth.session_service.create_session", side_effect=Exception("Session creation failed"), ): response = await client.post( @@ -43,7 +43,7 @@ class TestOAuthLoginSessionCreationFailure: ): """Test OAuth login succeeds even if session creation fails.""" with patch( - "app.api.routes.auth.session_crud.create_session", + "app.api.routes.auth.session_service.create_session", side_effect=Exception("Session failed"), ): response = await client.post( @@ -76,7 +76,7 @@ class TestRefreshTokenSessionUpdateFailure: # Mock session update to fail with patch( - "app.api.routes.auth.session_crud.update_refresh_token", + "app.api.routes.auth.session_service.update_refresh_token", side_effect=Exception("Update failed"), ): response = await client.post( @@ -130,7 +130,7 @@ class TestLogoutWithNonExistentSession: tokens = response.json() # Mock session lookup to return None - with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None): + with patch("app.api.routes.auth.session_service.get_by_jti", return_value=None): response = await client.post( "/api/v1/auth/logout", headers={"Authorization": f"Bearer {tokens['access_token']}"}, @@ -157,7 +157,7 @@ class TestLogoutUnexpectedError: # Mock to raise unexpected error with patch( - "app.api.routes.auth.session_crud.get_by_jti", + "app.api.routes.auth.session_service.get_by_jti", side_effect=Exception("Unexpected error"), ): response = await client.post( @@ -186,7 +186,7 @@ class TestLogoutAllUnexpectedError: # Mock to raise database error with patch( - "app.api.routes.auth.session_crud.deactivate_all_user_sessions", + "app.api.routes.auth.session_service.deactivate_all_user_sessions", side_effect=Exception("DB error"), ): response = await client.post( @@ -212,7 +212,7 @@ class TestPasswordResetConfirmSessionInvalidation: # Mock session invalidation to fail with patch( - "app.api.routes.auth.session_crud.deactivate_all_user_sessions", + "app.api.routes.auth.session_service.deactivate_all_user_sessions", side_effect=Exception("Invalidation failed"), ): response = await client.post( diff --git a/backend/tests/api/test_auth_password_reset.py b/backend/tests/api/test_auth_password_reset.py index 108dbe8..5a72d29 100755 --- a/backend/tests/api/test_auth_password_reset.py +++ b/backend/tests/api/test_auth_password_reset.py @@ -334,7 +334,7 @@ class TestPasswordResetConfirm: token = create_password_reset_token(async_test_user.email) # Mock the database commit to raise an exception - with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get: + with patch("app.services.auth_service.user_repo.get_by_email") as mock_get: mock_get.side_effect = Exception("Database error") response = await client.post( diff --git a/backend/tests/api/test_auth_security.py b/backend/tests/api/test_auth_security.py index 3ce8df6..773d221 100644 --- a/backend/tests/api/test_auth_security.py +++ b/backend/tests/api/test_auth_security.py @@ -12,7 +12,7 @@ These tests prevent real-world attack scenarios. import pytest from httpx import AsyncClient -from app.crud.session import session as session_crud +from app.repositories.session import session_repo as session_crud from app.models.user import User diff --git a/backend/tests/api/test_oauth.py b/backend/tests/api/test_oauth.py index 2300e7d..a6a6ba3 100644 --- a/backend/tests/api/test_oauth.py +++ b/backend/tests/api/test_oauth.py @@ -8,7 +8,7 @@ from uuid import uuid4 import pytest -from app.crud.oauth import oauth_account +from app.repositories.oauth_account import oauth_account_repo as oauth_account from app.schemas.oauth import OAuthAccountCreate @@ -349,7 +349,7 @@ class TestOAuthProviderEndpoints: _test_engine, AsyncTestingSessionLocal = async_test_db # Create a test client - from app.crud.oauth import oauth_client + from app.repositories.oauth_client import oauth_client_repo as oauth_client from app.schemas.oauth import OAuthClientCreate async with AsyncTestingSessionLocal() as session: @@ -386,7 +386,7 @@ class TestOAuthProviderEndpoints: _test_engine, AsyncTestingSessionLocal = async_test_db # Create a test client - from app.crud.oauth import oauth_client + from app.repositories.oauth_client import oauth_client_repo as oauth_client from app.schemas.oauth import OAuthClientCreate async with AsyncTestingSessionLocal() as session: diff --git a/backend/tests/api/test_organizations.py b/backend/tests/api/test_organizations.py index 404a43e..7a1e56c 100644 --- a/backend/tests/api/test_organizations.py +++ b/backend/tests/api/test_organizations.py @@ -537,7 +537,7 @@ class TestOrganizationExceptionHandlers: ): """Test generic exception handler in get_my_organizations (covers lines 81-83).""" with patch( - "app.crud.organization.organization.get_user_organizations_with_details", + "app.api.routes.organizations.organization_service.get_user_organizations_with_details", side_effect=Exception("Database connection lost"), ): # The exception handler logs and re-raises, so we expect the exception @@ -554,7 +554,7 @@ class TestOrganizationExceptionHandlers: ): """Test generic exception handler in get_organization (covers lines 124-128).""" with patch( - "app.crud.organization.organization.get", + "app.api.routes.organizations.organization_service.get_organization", side_effect=Exception("Database timeout"), ): with pytest.raises(Exception, match="Database timeout"): @@ -569,7 +569,7 @@ class TestOrganizationExceptionHandlers: ): """Test generic exception handler in get_organization_members (covers lines 170-172).""" with patch( - "app.crud.organization.organization.get_organization_members", + "app.api.routes.organizations.organization_service.get_organization_members", side_effect=Exception("Connection pool exhausted"), ): with pytest.raises(Exception, match="Connection pool exhausted"): @@ -591,11 +591,11 @@ class TestOrganizationExceptionHandlers: admin_token = login_response.json()["access_token"] with patch( - "app.crud.organization.organization.get", + "app.api.routes.organizations.organization_service.get_organization", return_value=test_org_with_user_admin, ): with patch( - "app.crud.organization.organization.update", + "app.api.routes.organizations.organization_service.update_organization", side_effect=Exception("Write lock timeout"), ): with pytest.raises(Exception, match="Write lock timeout"): diff --git a/backend/tests/api/test_permissions_security.py b/backend/tests/api/test_permissions_security.py index 46ac706..1526d23 100644 --- a/backend/tests/api/test_permissions_security.py +++ b/backend/tests/api/test_permissions_security.py @@ -11,7 +11,7 @@ These tests prevent unauthorized access and privilege escalation. import pytest from httpx import AsyncClient -from app.crud.user import user as user_crud +from app.repositories.user import user_repo as user_crud from app.models.organization import Organization from app.models.user import User diff --git a/backend/tests/api/test_sessions.py b/backend/tests/api/test_sessions.py index acb0cbf..57e4ced 100644 --- a/backend/tests/api/test_sessions.py +++ b/backend/tests/api/test_sessions.py @@ -39,7 +39,7 @@ async def async_test_user2(async_test_db): _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - from app.crud.user import user as user_crud + from app.repositories.user import user_repo as user_crud from app.schemas.users import UserCreate user_data = UserCreate( @@ -191,7 +191,7 @@ class TestRevokeSession: # Verify session is deactivated async with SessionLocal() as session: - from app.crud.session import session as session_crud + from app.repositories.session import session_repo as session_crud revoked_session = await session_crud.get(session, id=str(session_id)) assert revoked_session.is_active is False @@ -268,7 +268,7 @@ class TestCleanupExpiredSessions: _test_engine, SessionLocal = async_test_db # Create expired and active sessions using CRUD to avoid greenlet issues - from app.crud.session import session as session_crud + from app.repositories.session import session_repo as session_crud from app.schemas.sessions import SessionCreate async with SessionLocal() as db: @@ -334,7 +334,7 @@ class TestCleanupExpiredSessions: _test_engine, SessionLocal = async_test_db # Create only active sessions using CRUD - from app.crud.session import session as session_crud + from app.repositories.session import session_repo as session_crud from app.schemas.sessions import SessionCreate async with SessionLocal() as db: @@ -384,7 +384,7 @@ class TestSessionsAdditionalCases: # Create multiple sessions async with SessionLocal() as session: - from app.crud.session import session as session_crud + from app.repositories.session import session_repo as session_crud from app.schemas.sessions import SessionCreate for i in range(5): @@ -431,7 +431,7 @@ class TestSessionsAdditionalCases: """Test cleanup with mix of active/inactive and expired/not-expired sessions.""" _test_engine, SessionLocal = async_test_db - from app.crud.session import session as session_crud + from app.repositories.session import session_repo as session_crud from app.schemas.sessions import SessionCreate async with SessionLocal() as db: @@ -502,10 +502,10 @@ class TestSessionExceptionHandlers: """Test list_sessions handles database errors (covers lines 104-106).""" from unittest.mock import patch - from app.crud import session as session_module + from app.repositories import session as session_module with patch.object( - session_module.session, + session_module.session_repo, "get_user_sessions", side_effect=Exception("Database error"), ): @@ -527,10 +527,10 @@ class TestSessionExceptionHandlers: from unittest.mock import patch from uuid import uuid4 - from app.crud import session as session_module + from app.repositories import session as session_module # First create a session to revoke - from app.crud.session import session as session_crud + from app.repositories.session import session_repo as session_crud from app.schemas.sessions import SessionCreate _test_engine, AsyncTestingSessionLocal = async_test_db @@ -550,7 +550,7 @@ class TestSessionExceptionHandlers: # Mock the deactivate method to raise an exception with patch.object( - session_module.session, + session_module.session_repo, "deactivate", side_effect=Exception("Database connection lost"), ): @@ -568,10 +568,10 @@ class TestSessionExceptionHandlers: """Test cleanup_expired_sessions handles database errors (covers lines 233-236).""" from unittest.mock import patch - from app.crud import session as session_module + from app.repositories import session as session_module with patch.object( - session_module.session, + session_module.session_repo, "cleanup_expired_for_user", side_effect=Exception("Cleanup failed"), ): diff --git a/backend/tests/api/test_users.py b/backend/tests/api/test_users.py index 74a7f23..f74dadb 100644 --- a/backend/tests/api/test_users.py +++ b/backend/tests/api/test_users.py @@ -99,7 +99,7 @@ class TestUpdateCurrentUser: from unittest.mock import patch with patch( - "app.api.routes.users.user_crud.update", side_effect=Exception("DB error") + "app.api.routes.users.user_service.update_user", side_effect=Exception("DB error") ): with pytest.raises(Exception): await client.patch( @@ -134,7 +134,7 @@ class TestUpdateCurrentUser: from unittest.mock import patch with patch( - "app.api.routes.users.user_crud.update", + "app.api.routes.users.user_service.update_user", side_effect=ValueError("Invalid value"), ): with pytest.raises(ValueError): @@ -224,7 +224,7 @@ class TestUpdateUserById: from unittest.mock import patch with patch( - "app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid") + "app.api.routes.users.user_service.update_user", side_effect=ValueError("Invalid") ): with pytest.raises(ValueError): await client.patch( @@ -241,7 +241,7 @@ class TestUpdateUserById: from unittest.mock import patch with patch( - "app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected") + "app.api.routes.users.user_service.update_user", side_effect=Exception("Unexpected") ): with pytest.raises(Exception): await client.patch( @@ -354,7 +354,7 @@ class TestDeleteUserById: from unittest.mock import patch with patch( - "app.api.routes.users.user_crud.soft_delete", + "app.api.routes.users.user_service.soft_delete_user", side_effect=ValueError("Cannot delete"), ): with pytest.raises(ValueError): @@ -371,7 +371,7 @@ class TestDeleteUserById: from unittest.mock import patch with patch( - "app.api.routes.users.user_crud.soft_delete", + "app.api.routes.users.user_service.soft_delete_user", side_effect=Exception("Unexpected"), ): with pytest.raises(Exception): diff --git a/backend/tests/e2e/test_admin_workflows.py b/backend/tests/e2e/test_admin_workflows.py index 6cfe15e..ad781f6 100644 --- a/backend/tests/e2e/test_admin_workflows.py +++ b/backend/tests/e2e/test_admin_workflows.py @@ -46,7 +46,7 @@ async def login_user(client, email: str, password: str = "SecurePassword123!"): async def create_superuser(e2e_db_session, email: str, password: str): """Create a superuser directly in the database.""" - from app.crud.user import user as user_crud + from app.repositories.user import user_repo as user_crud from app.schemas.users import UserCreate user_in = UserCreate( diff --git a/backend/tests/e2e/test_organization_workflows.py b/backend/tests/e2e/test_organization_workflows.py index 710c87a..fa24dbf 100644 --- a/backend/tests/e2e/test_organization_workflows.py +++ b/backend/tests/e2e/test_organization_workflows.py @@ -46,7 +46,7 @@ async def register_and_login(client, email: str, password: str = "SecurePassword async def create_superuser_and_login(client, db_session): """Helper to create a superuser directly in DB and login.""" - from app.crud.user import user as user_crud + from app.repositories.user import user_repo as user_crud from app.schemas.users import UserCreate email = f"admin-{uuid4().hex[:8]}@example.com" diff --git a/backend/tests/crud/__init__.py b/backend/tests/repositories/__init__.py similarity index 100% rename from backend/tests/crud/__init__.py rename to backend/tests/repositories/__init__.py diff --git a/backend/tests/crud/test_base.py b/backend/tests/repositories/test_base.py similarity index 94% rename from backend/tests/crud/test_base.py rename to backend/tests/repositories/test_base.py index e6a6b9c..8e98a4b 100644 --- a/backend/tests/crud/test_base.py +++ b/backend/tests/repositories/test_base.py @@ -11,7 +11,12 @@ import pytest from sqlalchemy.exc import DataError, IntegrityError, OperationalError from sqlalchemy.orm import joinedload -from app.crud.user import user as user_crud +from app.core.repository_exceptions import ( + DuplicateEntryError, + IntegrityConstraintError, + InvalidInputError, +) +from app.repositories.user import user_repo as user_crud from app.schemas.users import UserCreate, UserUpdate @@ -81,7 +86,7 @@ class TestCRUDBaseGetMulti: _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with pytest.raises(ValueError, match="skip must be non-negative"): + with pytest.raises(InvalidInputError, match="skip must be non-negative"): await user_crud.get_multi(session, skip=-1) @pytest.mark.asyncio @@ -90,7 +95,7 @@ class TestCRUDBaseGetMulti: _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with pytest.raises(ValueError, match="limit must be non-negative"): + with pytest.raises(InvalidInputError, match="limit must be non-negative"): await user_crud.get_multi(session, limit=-1) @pytest.mark.asyncio @@ -99,7 +104,7 @@ class TestCRUDBaseGetMulti: _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with pytest.raises(ValueError, match="Maximum limit is 1000"): + with pytest.raises(InvalidInputError, match="Maximum limit is 1000"): await user_crud.get_multi(session, limit=1001) @pytest.mark.asyncio @@ -140,7 +145,7 @@ class TestCRUDBaseCreate: last_name="Duplicate", ) - with pytest.raises(ValueError, match="already exists"): + with pytest.raises(DuplicateEntryError, match="already exists"): await user_crud.create(session, obj_in=user_data) @pytest.mark.asyncio @@ -165,7 +170,7 @@ class TestCRUDBaseCreate: last_name="User", ) - with pytest.raises(ValueError, match="Database integrity error"): + with pytest.raises(DuplicateEntryError, match="Database integrity error"): await user_crud.create(session, obj_in=user_data) @pytest.mark.asyncio @@ -244,7 +249,7 @@ class TestCRUDBaseUpdate: # Create another user async with SessionLocal() as session: - from app.crud.user import user as user_crud + from app.repositories.user import user_repo as user_crud user2_data = UserCreate( email="user2@example.com", @@ -268,7 +273,7 @@ class TestCRUDBaseUpdate: ): update_data = UserUpdate(email=async_test_user.email) - with pytest.raises(ValueError, match="already exists"): + with pytest.raises(DuplicateEntryError, match="already exists"): await user_crud.update( session, db_obj=user2_obj, obj_in=update_data ) @@ -302,7 +307,7 @@ class TestCRUDBaseUpdate: "statement", {}, Exception("constraint failed") ), ): - with pytest.raises(ValueError, match="Database integrity error"): + with pytest.raises(IntegrityConstraintError, match="Database integrity error"): await user_crud.update( session, db_obj=user, obj_in={"first_name": "Test"} ) @@ -322,7 +327,7 @@ class TestCRUDBaseUpdate: "statement", {}, Exception("connection error") ), ): - with pytest.raises(ValueError, match="Database operation failed"): + with pytest.raises(IntegrityConstraintError, match="Database operation failed"): await user_crud.update( session, db_obj=user, obj_in={"first_name": "Test"} ) @@ -403,7 +408,7 @@ class TestCRUDBaseRemove: ), ): with pytest.raises( - ValueError, match="Cannot delete.*referenced by other records" + IntegrityConstraintError, match="Cannot delete.*referenced by other records" ): await user_crud.remove(session, id=str(async_test_user.id)) @@ -442,7 +447,7 @@ class TestCRUDBaseGetMultiWithTotal: _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with pytest.raises(ValueError, match="skip must be non-negative"): + with pytest.raises(InvalidInputError, match="skip must be non-negative"): await user_crud.get_multi_with_total(session, skip=-1) @pytest.mark.asyncio @@ -451,7 +456,7 @@ class TestCRUDBaseGetMultiWithTotal: _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with pytest.raises(ValueError, match="limit must be non-negative"): + with pytest.raises(InvalidInputError, match="limit must be non-negative"): await user_crud.get_multi_with_total(session, limit=-1) @pytest.mark.asyncio @@ -460,7 +465,7 @@ class TestCRUDBaseGetMultiWithTotal: _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with pytest.raises(ValueError, match="Maximum limit is 1000"): + with pytest.raises(InvalidInputError, match="Maximum limit is 1000"): await user_crud.get_multi_with_total(session, limit=1001) @pytest.mark.asyncio @@ -827,7 +832,7 @@ class TestCRUDBasePaginationValidation: _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with pytest.raises(ValueError, match="skip must be non-negative"): + with pytest.raises(InvalidInputError, match="skip must be non-negative"): await user_crud.get_multi_with_total(session, skip=-1, limit=10) @pytest.mark.asyncio @@ -836,7 +841,7 @@ class TestCRUDBasePaginationValidation: _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with pytest.raises(ValueError, match="limit must be non-negative"): + with pytest.raises(InvalidInputError, match="limit must be non-negative"): await user_crud.get_multi_with_total(session, skip=0, limit=-1) @pytest.mark.asyncio @@ -845,7 +850,7 @@ class TestCRUDBasePaginationValidation: _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with pytest.raises(ValueError, match="Maximum limit is 1000"): + with pytest.raises(InvalidInputError, match="Maximum limit is 1000"): await user_crud.get_multi_with_total(session, skip=0, limit=1001) @pytest.mark.asyncio @@ -899,7 +904,7 @@ class TestCRUDBaseModelsWithoutSoftDelete: _test_engine, SessionLocal = async_test_db # Create an organization (which doesn't have deleted_at) - from app.crud.organization import organization as org_crud + from app.repositories.organization import organization_repo as org_crud from app.models.organization import Organization async with SessionLocal() as session: @@ -910,7 +915,7 @@ class TestCRUDBaseModelsWithoutSoftDelete: # Try to soft delete organization (should fail) async with SessionLocal() as session: - with pytest.raises(ValueError, match="does not have a deleted_at column"): + with pytest.raises(InvalidInputError, match="does not have a deleted_at column"): await org_crud.soft_delete(session, id=str(org_id)) @pytest.mark.asyncio @@ -919,7 +924,7 @@ class TestCRUDBaseModelsWithoutSoftDelete: _test_engine, SessionLocal = async_test_db # Create an organization (which doesn't have deleted_at) - from app.crud.organization import organization as org_crud + from app.repositories.organization import organization_repo as org_crud from app.models.organization import Organization async with SessionLocal() as session: @@ -930,7 +935,7 @@ class TestCRUDBaseModelsWithoutSoftDelete: # Try to restore organization (should fail) async with SessionLocal() as session: - with pytest.raises(ValueError, match="does not have a deleted_at column"): + with pytest.raises(InvalidInputError, match="does not have a deleted_at column"): await org_crud.restore(session, id=str(org_id)) @@ -950,7 +955,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions: _test_engine, SessionLocal = async_test_db # Create a session for the user - from app.crud.session import session as session_crud + from app.repositories.session import session_repo as session_crud from app.models.user_session import UserSession async with SessionLocal() as session: @@ -989,7 +994,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions: _test_engine, SessionLocal = async_test_db # Create multiple sessions for the user - from app.crud.session import session as session_crud + from app.repositories.session import session_repo as session_crud from app.models.user_session import UserSession async with SessionLocal() as session: diff --git a/backend/tests/crud/test_base_db_failures.py b/backend/tests/repositories/test_base_db_failures.py similarity index 97% rename from backend/tests/crud/test_base_db_failures.py rename to backend/tests/repositories/test_base_db_failures.py index 36e0991..e468062 100644 --- a/backend/tests/crud/test_base_db_failures.py +++ b/backend/tests/repositories/test_base_db_failures.py @@ -10,7 +10,8 @@ from uuid import uuid4 import pytest from sqlalchemy.exc import DataError, OperationalError -from app.crud.user import user as user_crud +from app.core.repository_exceptions import IntegrityConstraintError +from app.repositories.user import user_repo as user_crud from app.schemas.users import UserCreate @@ -119,7 +120,7 @@ class TestBaseCRUDUpdateFailures: with patch.object( session, "rollback", new_callable=AsyncMock ) as mock_rollback: - with pytest.raises(ValueError, match="Database operation failed"): + with pytest.raises(IntegrityConstraintError, match="Database operation failed"): await user_crud.update( session, db_obj=user, obj_in={"first_name": "Updated"} ) @@ -141,7 +142,7 @@ class TestBaseCRUDUpdateFailures: with patch.object( session, "rollback", new_callable=AsyncMock ) as mock_rollback: - with pytest.raises(ValueError, match="Database operation failed"): + with pytest.raises(IntegrityConstraintError, match="Database operation failed"): await user_crud.update( session, db_obj=user, obj_in={"first_name": "Updated"} ) diff --git a/backend/tests/crud/test_oauth.py b/backend/tests/repositories/test_oauth.py similarity index 97% rename from backend/tests/crud/test_oauth.py rename to backend/tests/repositories/test_oauth.py index a126e05..45a8077 100644 --- a/backend/tests/crud/test_oauth.py +++ b/backend/tests/repositories/test_oauth.py @@ -7,7 +7,10 @@ from datetime import UTC, datetime, timedelta import pytest -from app.crud.oauth import oauth_account, oauth_client, oauth_state +from app.core.repository_exceptions import DuplicateEntryError +from app.repositories.oauth_account import oauth_account_repo as oauth_account +from app.repositories.oauth_client import oauth_client_repo as oauth_client +from app.repositories.oauth_state import oauth_state_repo as oauth_state from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate @@ -60,7 +63,7 @@ class TestOAuthAccountCRUD: # SQLite returns different error message than PostgreSQL with pytest.raises( - ValueError, match="(already linked|UNIQUE constraint failed)" + DuplicateEntryError, match="(already linked|UNIQUE constraint failed|Failed to create)" ): await oauth_account.create_account(session, obj_in=account_data2) @@ -256,13 +259,13 @@ class TestOAuthAccountCRUD: updated = await oauth_account.update_tokens( session, account=account, - access_token_encrypted="new_access_token", - refresh_token_encrypted="new_refresh_token", + access_token="new_access_token", + refresh_token="new_refresh_token", token_expires_at=new_expires, ) - assert updated.access_token_encrypted == "new_access_token" - assert updated.refresh_token_encrypted == "new_refresh_token" + assert updated.access_token == "new_access_token" + assert updated.refresh_token == "new_refresh_token" class TestOAuthStateCRUD: diff --git a/backend/tests/crud/test_organization.py b/backend/tests/repositories/test_organization.py similarity index 98% rename from backend/tests/crud/test_organization.py rename to backend/tests/repositories/test_organization.py index 1a7b2e9..d544320 100644 --- a/backend/tests/crud/test_organization.py +++ b/backend/tests/repositories/test_organization.py @@ -9,7 +9,8 @@ from uuid import uuid4 import pytest from sqlalchemy import select -from app.crud.organization import organization as organization_crud +from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError +from app.repositories.organization import organization_repo as organization_crud from app.models.organization import Organization from app.models.user_organization import OrganizationRole, UserOrganization from app.schemas.organizations import OrganizationCreate @@ -87,7 +88,7 @@ class TestCreate: # Try to create second with same slug async with AsyncTestingSessionLocal() as session: org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug") - with pytest.raises(ValueError, match="already exists"): + with pytest.raises(DuplicateEntryError, match="already exists"): await organization_crud.create(session, obj_in=org_in) @pytest.mark.asyncio @@ -295,7 +296,7 @@ class TestAddUser: org_id = org.id async with AsyncTestingSessionLocal() as session: - with pytest.raises(ValueError, match="already a member"): + with pytest.raises(DuplicateEntryError, match="already a member"): await organization_crud.add_user( session, organization_id=org_id, user_id=async_test_user.id ) @@ -972,7 +973,7 @@ class TestOrganizationExceptionHandlers: with patch.object(session, "commit", side_effect=mock_commit): with patch.object(session, "rollback", new_callable=AsyncMock): org_in = OrganizationCreate(name="Test", slug="test") - with pytest.raises(ValueError, match="Database integrity error"): + with pytest.raises(IntegrityConstraintError, match="Database integrity error"): await organization_crud.create(session, obj_in=org_in) @pytest.mark.asyncio @@ -1058,7 +1059,7 @@ class TestOrganizationExceptionHandlers: with patch.object(session, "commit", side_effect=mock_commit): with patch.object(session, "rollback", new_callable=AsyncMock): with pytest.raises( - ValueError, match="Failed to add user to organization" + IntegrityConstraintError, match="Failed to add user to organization" ): await organization_crud.add_user( session, diff --git a/backend/tests/crud/test_session.py b/backend/tests/repositories/test_session.py similarity index 99% rename from backend/tests/crud/test_session.py rename to backend/tests/repositories/test_session.py index 8b540fa..0c1c902 100644 --- a/backend/tests/crud/test_session.py +++ b/backend/tests/repositories/test_session.py @@ -8,7 +8,8 @@ from uuid import uuid4 import pytest -from app.crud.session import session as session_crud +from app.core.repository_exceptions import InvalidInputError +from app.repositories.session import session_repo as session_crud from app.models.user_session import UserSession from app.schemas.sessions import SessionCreate @@ -503,7 +504,7 @@ class TestCleanupExpiredForUser: _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with pytest.raises(ValueError, match="Invalid user ID format"): + with pytest.raises(InvalidInputError, match="Invalid user ID format"): await session_crud.cleanup_expired_for_user( session, user_id="not-a-valid-uuid" ) diff --git a/backend/tests/crud/test_session_db_failures.py b/backend/tests/repositories/test_session_db_failures.py similarity index 97% rename from backend/tests/crud/test_session_db_failures.py rename to backend/tests/repositories/test_session_db_failures.py index dabf0a1..9bf3ef9 100644 --- a/backend/tests/crud/test_session_db_failures.py +++ b/backend/tests/repositories/test_session_db_failures.py @@ -10,7 +10,8 @@ from uuid import uuid4 import pytest from sqlalchemy.exc import OperationalError -from app.crud.session import session as session_crud +from app.core.repository_exceptions import IntegrityConstraintError +from app.repositories.session import session_repo as session_crud from app.models.user_session import UserSession from app.schemas.sessions import SessionCreate @@ -102,7 +103,7 @@ class TestSessionCRUDCreateSessionFailures: last_used_at=datetime.now(UTC), ) - with pytest.raises(ValueError, match="Failed to create session"): + with pytest.raises(IntegrityConstraintError, match="Failed to create session"): await session_crud.create_session(session, obj_in=session_data) mock_rollback.assert_called_once() @@ -133,7 +134,7 @@ class TestSessionCRUDCreateSessionFailures: last_used_at=datetime.now(UTC), ) - with pytest.raises(ValueError, match="Failed to create session"): + with pytest.raises(IntegrityConstraintError, match="Failed to create session"): await session_crud.create_session(session, obj_in=session_data) mock_rollback.assert_called_once() diff --git a/backend/tests/crud/test_user.py b/backend/tests/repositories/test_user.py similarity index 98% rename from backend/tests/crud/test_user.py rename to backend/tests/repositories/test_user.py index 0500a90..493c9ed 100644 --- a/backend/tests/crud/test_user.py +++ b/backend/tests/repositories/test_user.py @@ -5,7 +5,8 @@ Comprehensive tests for async user CRUD operations. import pytest -from app.crud.user import user as user_crud +from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError +from app.repositories.user import user_repo as user_crud from app.schemas.users import UserCreate, UserUpdate @@ -93,7 +94,7 @@ class TestCreate: last_name="User", ) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(DuplicateEntryError) as exc_info: await user_crud.create(session, obj_in=user_data) assert "already exists" in str(exc_info.value).lower() @@ -330,7 +331,7 @@ class TestGetMultiWithTotal: _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with pytest.raises(ValueError) as exc_info: + with pytest.raises(InvalidInputError) as exc_info: await user_crud.get_multi_with_total(session, skip=-1, limit=10) assert "skip must be non-negative" in str(exc_info.value) @@ -341,7 +342,7 @@ class TestGetMultiWithTotal: _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with pytest.raises(ValueError) as exc_info: + with pytest.raises(InvalidInputError) as exc_info: await user_crud.get_multi_with_total(session, skip=0, limit=-1) assert "limit must be non-negative" in str(exc_info.value) @@ -352,7 +353,7 @@ class TestGetMultiWithTotal: _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with pytest.raises(ValueError) as exc_info: + with pytest.raises(InvalidInputError) as exc_info: await user_crud.get_multi_with_total(session, skip=0, limit=1001) assert "Maximum limit is 1000" in str(exc_info.value) diff --git a/backend/tests/services/test_auth_service.py b/backend/tests/services/test_auth_service.py index cf6f84c..08c1ca3 100755 --- a/backend/tests/services/test_auth_service.py +++ b/backend/tests/services/test_auth_service.py @@ -10,6 +10,7 @@ from app.core.auth import ( get_password_hash, verify_password, ) +from app.core.exceptions import DuplicateError from app.models.user import User from app.schemas.users import Token, UserCreate from app.services.auth_service import AuthenticationError, AuthService @@ -152,9 +153,9 @@ class TestAuthServiceUserCreation: last_name="User", ) - # Should raise AuthenticationError + # Should raise DuplicateError for duplicate email async with AsyncTestingSessionLocal() as session: - with pytest.raises(AuthenticationError): + with pytest.raises(DuplicateError): await AuthService.create_user(db=session, user_data=user_data) diff --git a/backend/tests/services/test_oauth_provider_service.py b/backend/tests/services/test_oauth_provider_service.py index 0dfdf90..db22cb1 100644 --- a/backend/tests/services/test_oauth_provider_service.py +++ b/backend/tests/services/test_oauth_provider_service.py @@ -269,18 +269,18 @@ class TestClientValidation: async def test_validate_client_legacy_sha256_hash( self, db, confidential_client_legacy_hash ): - """Test validating a client with legacy SHA-256 hash (backward compatibility).""" + """Test that legacy SHA-256 hash is rejected with clear error message.""" client, secret = confidential_client_legacy_hash - validated = await service.validate_client(db, client.client_id, secret) - assert validated.client_id == client.client_id + with pytest.raises(service.InvalidClientError, match="deprecated hash format"): + await service.validate_client(db, client.client_id, secret) @pytest.mark.asyncio async def test_validate_client_legacy_sha256_wrong_secret( self, db, confidential_client_legacy_hash ): - """Test legacy SHA-256 client rejects wrong secret.""" + """Test that legacy SHA-256 client with wrong secret is rejected.""" client, _ = confidential_client_legacy_hash - with pytest.raises(service.InvalidClientError, match="Invalid client secret"): + with pytest.raises(service.InvalidClientError, match="deprecated hash format"): await service.validate_client(db, client.client_id, "wrong_secret") def test_validate_redirect_uri_success(self, public_client): diff --git a/backend/tests/services/test_oauth_service.py b/backend/tests/services/test_oauth_service.py index 7fad254..b9d3fd2 100644 --- a/backend/tests/services/test_oauth_service.py +++ b/backend/tests/services/test_oauth_service.py @@ -11,7 +11,8 @@ from uuid import uuid4 import pytest from app.core.exceptions import AuthenticationError -from app.crud.oauth import oauth_account, oauth_state +from app.repositories.oauth_account import oauth_account_repo as oauth_account +from app.repositories.oauth_state import oauth_state_repo as oauth_state from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService diff --git a/backend/tests/services/test_organization_service.py b/backend/tests/services/test_organization_service.py new file mode 100644 index 0000000..681a813 --- /dev/null +++ b/backend/tests/services/test_organization_service.py @@ -0,0 +1,447 @@ +# tests/services/test_organization_service.py +"""Tests for the OrganizationService class.""" + +import uuid + +import pytest +import pytest_asyncio + +from app.core.exceptions import NotFoundError +from app.models.user_organization import OrganizationRole +from app.schemas.organizations import OrganizationCreate, OrganizationUpdate +from app.services.organization_service import OrganizationService, organization_service + + +def _make_org_create(name=None, slug=None) -> OrganizationCreate: + """Helper to create an OrganizationCreate schema with unique defaults.""" + unique = uuid.uuid4().hex[:8] + return OrganizationCreate( + name=name or f"Test Org {unique}", + slug=slug or f"test-org-{unique}", + description="A test organization", + is_active=True, + settings={}, + ) + + +class TestGetOrganization: + """Tests for OrganizationService.get_organization method.""" + + @pytest.mark.asyncio + async def test_get_organization_found(self, async_test_db, async_test_user): + """Test getting an existing organization by ID returns the org.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + result = await organization_service.get_organization( + session, str(created.id) + ) + assert result is not None + assert result.id == created.id + assert result.slug == created.slug + + @pytest.mark.asyncio + async def test_get_organization_not_found(self, async_test_db): + """Test getting a non-existent organization raises NotFoundError.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + with pytest.raises(NotFoundError): + await organization_service.get_organization( + session, str(uuid.uuid4()) + ) + + +class TestCreateOrganization: + """Tests for OrganizationService.create_organization method.""" + + @pytest.mark.asyncio + async def test_create_organization(self, async_test_db, async_test_user): + """Test creating a new organization returns the created org with correct fields.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + obj_in = _make_org_create() + async with AsyncTestingSessionLocal() as session: + result = await organization_service.create_organization( + session, obj_in=obj_in + ) + assert result is not None + assert result.name == obj_in.name + assert result.slug == obj_in.slug + assert result.description == obj_in.description + assert result.is_active is True + + +class TestUpdateOrganization: + """Tests for OrganizationService.update_organization method.""" + + @pytest.mark.asyncio + async def test_update_organization(self, async_test_db, async_test_user): + """Test updating an organization name.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + org = await organization_service.get_organization(session, str(created.id)) + updated = await organization_service.update_organization( + session, + org=org, + obj_in=OrganizationUpdate(name="Updated Org Name"), + ) + assert updated.name == "Updated Org Name" + assert updated.id == created.id + + @pytest.mark.asyncio + async def test_update_organization_with_dict(self, async_test_db, async_test_user): + """Test updating an organization using a dict.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + org = await organization_service.get_organization(session, str(created.id)) + updated = await organization_service.update_organization( + session, + org=org, + obj_in={"description": "Updated description"}, + ) + assert updated.description == "Updated description" + + +class TestRemoveOrganization: + """Tests for OrganizationService.remove_organization method.""" + + @pytest.mark.asyncio + async def test_remove_organization(self, async_test_db, async_test_user): + """Test permanently deleting an organization.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + org_id = str(created.id) + + async with AsyncTestingSessionLocal() as session: + await organization_service.remove_organization(session, org_id) + + async with AsyncTestingSessionLocal() as session: + with pytest.raises(NotFoundError): + await organization_service.get_organization(session, org_id) + + +class TestGetMemberCount: + """Tests for OrganizationService.get_member_count method.""" + + @pytest.mark.asyncio + async def test_get_member_count_empty(self, async_test_db, async_test_user): + """Test member count for org with no members is zero.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + count = await organization_service.get_member_count( + session, organization_id=created.id + ) + assert count == 0 + + @pytest.mark.asyncio + async def test_get_member_count_with_member(self, async_test_db, async_test_user): + """Test member count increases after adding a member.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + await organization_service.add_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + ) + + async with AsyncTestingSessionLocal() as session: + count = await organization_service.get_member_count( + session, organization_id=created.id + ) + assert count == 1 + + +class TestGetMultiWithMemberCounts: + """Tests for OrganizationService.get_multi_with_member_counts method.""" + + @pytest.mark.asyncio + async def test_get_multi_with_member_counts(self, async_test_db, async_test_user): + """Test listing organizations with member counts returns tuple.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + orgs, count = await organization_service.get_multi_with_member_counts( + session, skip=0, limit=10 + ) + assert isinstance(orgs, list) + assert isinstance(count, int) + assert count >= 1 + + @pytest.mark.asyncio + async def test_get_multi_with_member_counts_search( + self, async_test_db, async_test_user + ): + """Test listing organizations with a search filter.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + unique = uuid.uuid4().hex[:8] + org_name = f"Searchable Org {unique}" + async with AsyncTestingSessionLocal() as session: + await organization_service.create_organization( + session, + obj_in=OrganizationCreate( + name=org_name, + slug=f"searchable-org-{unique}", + is_active=True, + settings={}, + ), + ) + + async with AsyncTestingSessionLocal() as session: + orgs, count = await organization_service.get_multi_with_member_counts( + session, skip=0, limit=10, search=f"Searchable Org {unique}" + ) + assert count >= 1 + # Each element is a dict with key "organization" (an Organization obj) and "member_count" + names = [o["organization"].name for o in orgs] + assert org_name in names + + +class TestGetUserOrganizationsWithDetails: + """Tests for OrganizationService.get_user_organizations_with_details method.""" + + @pytest.mark.asyncio + async def test_get_user_organizations_with_details( + self, async_test_db, async_test_user + ): + """Test getting organizations for a user returns list of dicts.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + await organization_service.add_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + ) + + async with AsyncTestingSessionLocal() as session: + orgs = await organization_service.get_user_organizations_with_details( + session, user_id=async_test_user.id + ) + assert isinstance(orgs, list) + assert len(orgs) >= 1 + + +class TestGetOrganizationMembers: + """Tests for OrganizationService.get_organization_members method.""" + + @pytest.mark.asyncio + async def test_get_organization_members(self, async_test_db, async_test_user): + """Test getting organization members returns paginated results.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + await organization_service.add_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + ) + + async with AsyncTestingSessionLocal() as session: + members, count = await organization_service.get_organization_members( + session, organization_id=created.id, skip=0, limit=10 + ) + assert isinstance(members, list) + assert isinstance(count, int) + assert count >= 1 + + +class TestAddMember: + """Tests for OrganizationService.add_member method.""" + + @pytest.mark.asyncio + async def test_add_member_default_role(self, async_test_db, async_test_user): + """Test adding a user to an org with default MEMBER role.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + membership = await organization_service.add_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + ) + assert membership is not None + assert membership.user_id == async_test_user.id + assert membership.organization_id == created.id + assert membership.role == OrganizationRole.MEMBER + + @pytest.mark.asyncio + async def test_add_member_admin_role(self, async_test_db, async_test_user): + """Test adding a user to an org with ADMIN role.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + membership = await organization_service.add_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + role=OrganizationRole.ADMIN, + ) + assert membership.role == OrganizationRole.ADMIN + + +class TestRemoveMember: + """Tests for OrganizationService.remove_member method.""" + + @pytest.mark.asyncio + async def test_remove_member(self, async_test_db, async_test_user): + """Test removing a member from an org returns True.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + await organization_service.add_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + ) + + async with AsyncTestingSessionLocal() as session: + removed = await organization_service.remove_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + ) + assert removed is True + + @pytest.mark.asyncio + async def test_remove_member_not_found(self, async_test_db, async_test_user): + """Test removing a non-member returns False.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + removed = await organization_service.remove_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + ) + assert removed is False + + +class TestGetUserRoleInOrg: + """Tests for OrganizationService.get_user_role_in_org method.""" + + @pytest.mark.asyncio + async def test_get_user_role_in_org(self, async_test_db, async_test_user): + """Test getting a user's role in an org they belong to.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + await organization_service.add_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + role=OrganizationRole.MEMBER, + ) + + async with AsyncTestingSessionLocal() as session: + role = await organization_service.get_user_role_in_org( + session, + user_id=async_test_user.id, + organization_id=created.id, + ) + assert role == OrganizationRole.MEMBER + + @pytest.mark.asyncio + async def test_get_user_role_in_org_not_member( + self, async_test_db, async_test_user + ): + """Test getting role for a user not in the org returns None.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + + async with AsyncTestingSessionLocal() as session: + role = await organization_service.get_user_role_in_org( + session, + user_id=async_test_user.id, + organization_id=created.id, + ) + assert role is None + + +class TestGetOrgDistribution: + """Tests for OrganizationService.get_org_distribution method.""" + + @pytest.mark.asyncio + async def test_get_org_distribution_empty(self, async_test_db): + """Test org distribution with no memberships returns empty list.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + result = await organization_service.get_org_distribution(session, limit=6) + assert isinstance(result, list) + + @pytest.mark.asyncio + async def test_get_org_distribution_with_members( + self, async_test_db, async_test_user + ): + """Test org distribution returns org name and member count.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + created = await organization_service.create_organization( + session, obj_in=_make_org_create() + ) + await organization_service.add_member( + session, + organization_id=created.id, + user_id=async_test_user.id, + ) + + async with AsyncTestingSessionLocal() as session: + result = await organization_service.get_org_distribution(session, limit=6) + assert isinstance(result, list) + assert len(result) >= 1 + entry = result[0] + assert "name" in entry + assert "value" in entry + assert entry["value"] >= 1 diff --git a/backend/tests/services/test_session_service.py b/backend/tests/services/test_session_service.py new file mode 100644 index 0000000..e6dfb2b --- /dev/null +++ b/backend/tests/services/test_session_service.py @@ -0,0 +1,292 @@ +# tests/services/test_session_service.py +"""Tests for the SessionService class.""" + +import uuid +from datetime import UTC, datetime, timedelta + +import pytest +import pytest_asyncio + +from app.schemas.sessions import SessionCreate +from app.services.session_service import SessionService, session_service + + +def _make_session_create(user_id, jti=None) -> SessionCreate: + """Helper to build a SessionCreate with sensible defaults.""" + now = datetime.now(UTC) + return SessionCreate( + user_id=user_id, + refresh_token_jti=jti or str(uuid.uuid4()), + ip_address="127.0.0.1", + user_agent="pytest/test", + device_name="Test Device", + device_id="test-device-id", + last_used_at=now, + expires_at=now + timedelta(days=7), + location_city="TestCity", + location_country="TestCountry", + ) + + +class TestCreateSession: + """Tests for SessionService.create_session method.""" + + @pytest.mark.asyncio + async def test_create_session(self, async_test_db, async_test_user): + """Test creating a session returns a UserSession with correct fields.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + obj_in = _make_session_create(async_test_user.id) + async with AsyncTestingSessionLocal() as session: + result = await session_service.create_session(session, obj_in=obj_in) + assert result is not None + assert result.user_id == async_test_user.id + assert result.refresh_token_jti == obj_in.refresh_token_jti + assert result.is_active is True + assert result.ip_address == "127.0.0.1" + + +class TestGetSession: + """Tests for SessionService.get_session method.""" + + @pytest.mark.asyncio + async def test_get_session_found(self, async_test_db, async_test_user): + """Test getting a session by ID returns the session.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + obj_in = _make_session_create(async_test_user.id) + + async with AsyncTestingSessionLocal() as session: + created = await session_service.create_session(session, obj_in=obj_in) + + async with AsyncTestingSessionLocal() as session: + result = await session_service.get_session(session, str(created.id)) + assert result is not None + assert result.id == created.id + + @pytest.mark.asyncio + async def test_get_session_not_found(self, async_test_db): + """Test getting a non-existent session returns None.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + result = await session_service.get_session(session, str(uuid.uuid4())) + assert result is None + + +class TestGetUserSessions: + """Tests for SessionService.get_user_sessions method.""" + + @pytest.mark.asyncio + async def test_get_user_sessions_active_only(self, async_test_db, async_test_user): + """Test getting active sessions for a user returns only active sessions.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + obj_in = _make_session_create(async_test_user.id) + + async with AsyncTestingSessionLocal() as session: + await session_service.create_session(session, obj_in=obj_in) + + async with AsyncTestingSessionLocal() as session: + sessions = await session_service.get_user_sessions( + session, user_id=str(async_test_user.id), active_only=True + ) + assert isinstance(sessions, list) + assert len(sessions) >= 1 + for s in sessions: + assert s.is_active is True + + @pytest.mark.asyncio + async def test_get_user_sessions_all(self, async_test_db, async_test_user): + """Test getting all sessions (active and inactive) for a user.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + obj_in = _make_session_create(async_test_user.id) + + async with AsyncTestingSessionLocal() as session: + created = await session_service.create_session(session, obj_in=obj_in) + await session_service.deactivate(session, session_id=str(created.id)) + + async with AsyncTestingSessionLocal() as session: + sessions = await session_service.get_user_sessions( + session, user_id=str(async_test_user.id), active_only=False + ) + assert isinstance(sessions, list) + assert len(sessions) >= 1 + + +class TestGetActiveByJti: + """Tests for SessionService.get_active_by_jti method.""" + + @pytest.mark.asyncio + async def test_get_active_by_jti_found(self, async_test_db, async_test_user): + """Test getting an active session by JTI returns the session.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + jti = str(uuid.uuid4()) + obj_in = _make_session_create(async_test_user.id, jti=jti) + + async with AsyncTestingSessionLocal() as session: + await session_service.create_session(session, obj_in=obj_in) + + async with AsyncTestingSessionLocal() as session: + result = await session_service.get_active_by_jti(session, jti=jti) + assert result is not None + assert result.refresh_token_jti == jti + assert result.is_active is True + + @pytest.mark.asyncio + async def test_get_active_by_jti_not_found(self, async_test_db): + """Test getting an active session by non-existent JTI returns None.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + result = await session_service.get_active_by_jti( + session, jti=str(uuid.uuid4()) + ) + assert result is None + + +class TestGetByJti: + """Tests for SessionService.get_by_jti method.""" + + @pytest.mark.asyncio + async def test_get_by_jti_active(self, async_test_db, async_test_user): + """Test getting a session (active or inactive) by JTI.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + jti = str(uuid.uuid4()) + obj_in = _make_session_create(async_test_user.id, jti=jti) + + async with AsyncTestingSessionLocal() as session: + await session_service.create_session(session, obj_in=obj_in) + + async with AsyncTestingSessionLocal() as session: + result = await session_service.get_by_jti(session, jti=jti) + assert result is not None + assert result.refresh_token_jti == jti + + +class TestDeactivate: + """Tests for SessionService.deactivate method.""" + + @pytest.mark.asyncio + async def test_deactivate_session(self, async_test_db, async_test_user): + """Test deactivating a session sets is_active to False.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + obj_in = _make_session_create(async_test_user.id) + + async with AsyncTestingSessionLocal() as session: + created = await session_service.create_session(session, obj_in=obj_in) + session_id = str(created.id) + + async with AsyncTestingSessionLocal() as session: + deactivated = await session_service.deactivate( + session, session_id=session_id + ) + assert deactivated is not None + assert deactivated.is_active is False + + +class TestDeactivateAllUserSessions: + """Tests for SessionService.deactivate_all_user_sessions method.""" + + @pytest.mark.asyncio + async def test_deactivate_all_user_sessions(self, async_test_db, async_test_user): + """Test deactivating all sessions for a user returns count deactivated.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + await session_service.create_session( + session, obj_in=_make_session_create(async_test_user.id) + ) + await session_service.create_session( + session, obj_in=_make_session_create(async_test_user.id) + ) + + async with AsyncTestingSessionLocal() as session: + count = await session_service.deactivate_all_user_sessions( + session, user_id=str(async_test_user.id) + ) + assert count >= 2 + + async with AsyncTestingSessionLocal() as session: + active_sessions = await session_service.get_user_sessions( + session, user_id=str(async_test_user.id), active_only=True + ) + assert len(active_sessions) == 0 + + +class TestUpdateRefreshToken: + """Tests for SessionService.update_refresh_token method.""" + + @pytest.mark.asyncio + async def test_update_refresh_token(self, async_test_db, async_test_user): + """Test rotating a session's refresh token updates JTI and expiry.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + obj_in = _make_session_create(async_test_user.id) + + async with AsyncTestingSessionLocal() as session: + created = await session_service.create_session(session, obj_in=obj_in) + session_id = str(created.id) + + new_jti = str(uuid.uuid4()) + new_expires_at = datetime.now(UTC) + timedelta(days=14) + + async with AsyncTestingSessionLocal() as session: + result = await session_service.get_session(session, session_id) + updated = await session_service.update_refresh_token( + session, + session=result, + new_jti=new_jti, + new_expires_at=new_expires_at, + ) + assert updated.refresh_token_jti == new_jti + + +class TestCleanupExpiredForUser: + """Tests for SessionService.cleanup_expired_for_user method.""" + + @pytest.mark.asyncio + async def test_cleanup_expired_for_user(self, async_test_db, async_test_user): + """Test cleaning up expired inactive sessions returns count removed.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + now = datetime.now(UTC) + # Create a session that is already expired + obj_in = SessionCreate( + user_id=async_test_user.id, + refresh_token_jti=str(uuid.uuid4()), + ip_address="127.0.0.1", + user_agent="pytest/test", + last_used_at=now - timedelta(days=8), + expires_at=now - timedelta(days=1), + ) + + async with AsyncTestingSessionLocal() as session: + created = await session_service.create_session(session, obj_in=obj_in) + session_id = str(created.id) + + # Deactivate it so it qualifies for cleanup (requires is_active=False AND expired) + async with AsyncTestingSessionLocal() as session: + await session_service.deactivate(session, session_id=session_id) + + async with AsyncTestingSessionLocal() as session: + count = await session_service.cleanup_expired_for_user( + session, user_id=str(async_test_user.id) + ) + assert isinstance(count, int) + assert count >= 1 + + +class TestGetAllSessions: + """Tests for SessionService.get_all_sessions method.""" + + @pytest.mark.asyncio + async def test_get_all_sessions(self, async_test_db, async_test_user): + """Test getting all sessions with pagination returns tuple of list and count.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + obj_in = _make_session_create(async_test_user.id) + + async with AsyncTestingSessionLocal() as session: + await session_service.create_session(session, obj_in=obj_in) + + async with AsyncTestingSessionLocal() as session: + sessions, count = await session_service.get_all_sessions( + session, skip=0, limit=10, active_only=True, with_user=False + ) + assert isinstance(sessions, list) + assert isinstance(count, int) + assert count >= 1 + assert len(sessions) >= 1 diff --git a/backend/tests/services/test_user_service.py b/backend/tests/services/test_user_service.py new file mode 100644 index 0000000..47ab065 --- /dev/null +++ b/backend/tests/services/test_user_service.py @@ -0,0 +1,214 @@ +# tests/services/test_user_service.py +"""Tests for the UserService class.""" + +import uuid + +import pytest +import pytest_asyncio +from sqlalchemy import select + +from app.core.exceptions import NotFoundError +from app.models.user import User +from app.schemas.users import UserCreate, UserUpdate +from app.services.user_service import UserService, user_service + + +class TestGetUser: + """Tests for UserService.get_user method.""" + + @pytest.mark.asyncio + async def test_get_user_found(self, async_test_db, async_test_user): + """Test getting an existing user by ID returns the user.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + result = await user_service.get_user(session, str(async_test_user.id)) + assert result is not None + assert result.id == async_test_user.id + assert result.email == async_test_user.email + + @pytest.mark.asyncio + async def test_get_user_not_found(self, async_test_db): + """Test getting a non-existent user raises NotFoundError.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + non_existent_id = str(uuid.uuid4()) + async with AsyncTestingSessionLocal() as session: + with pytest.raises(NotFoundError): + await user_service.get_user(session, non_existent_id) + + +class TestGetByEmail: + """Tests for UserService.get_by_email method.""" + + @pytest.mark.asyncio + async def test_get_by_email_found(self, async_test_db, async_test_user): + """Test getting an existing user by email returns the user.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + result = await user_service.get_by_email(session, async_test_user.email) + assert result is not None + assert result.id == async_test_user.id + assert result.email == async_test_user.email + + @pytest.mark.asyncio + async def test_get_by_email_not_found(self, async_test_db): + """Test getting a user by non-existent email returns None.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + result = await user_service.get_by_email(session, "nonexistent@example.com") + assert result is None + + +class TestCreateUser: + """Tests for UserService.create_user method.""" + + @pytest.mark.asyncio + async def test_create_user(self, async_test_db): + """Test creating a new user with valid data.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + unique_email = f"test_{uuid.uuid4()}@example.com" + user_data = UserCreate( + email=unique_email, + password="TestPassword123!", + first_name="New", + last_name="User", + ) + async with AsyncTestingSessionLocal() as session: + result = await user_service.create_user(session, user_data) + assert result is not None + assert result.email == unique_email + assert result.first_name == "New" + assert result.last_name == "User" + assert result.is_active is True + + +class TestUpdateUser: + """Tests for UserService.update_user method.""" + + @pytest.mark.asyncio + async def test_update_user(self, async_test_db, async_test_user): + """Test updating a user's first_name.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + user = await user_service.get_user(session, str(async_test_user.id)) + updated = await user_service.update_user( + session, + user=user, + obj_in=UserUpdate(first_name="Updated"), + ) + assert updated.first_name == "Updated" + assert updated.id == async_test_user.id + + +class TestSoftDeleteUser: + """Tests for UserService.soft_delete_user method.""" + + @pytest.mark.asyncio + async def test_soft_delete_user(self, async_test_db, async_test_user): + """Test soft-deleting a user sets deleted_at.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + await user_service.soft_delete_user(session, str(async_test_user.id)) + + async with AsyncTestingSessionLocal() as session: + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) + user = result.scalar_one_or_none() + assert user is not None + assert user.deleted_at is not None + + +class TestListUsers: + """Tests for UserService.list_users method.""" + + @pytest.mark.asyncio + async def test_list_users(self, async_test_db, async_test_user): + """Test listing users with pagination returns correct results.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + users, count = await user_service.list_users(session, skip=0, limit=10) + assert isinstance(users, list) + assert isinstance(count, int) + assert count >= 1 + assert len(users) >= 1 + + @pytest.mark.asyncio + async def test_list_users_with_search(self, async_test_db, async_test_user): + """Test listing users with email fragment search returns matching users.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + # Search by partial email fragment of the test user + email_fragment = async_test_user.email.split("@")[0] + async with AsyncTestingSessionLocal() as session: + users, count = await user_service.list_users( + session, skip=0, limit=10, search=email_fragment + ) + assert isinstance(users, list) + assert count >= 1 + emails = [u.email for u in users] + assert async_test_user.email in emails + + +class TestBulkUpdateStatus: + """Tests for UserService.bulk_update_status method.""" + + @pytest.mark.asyncio + async def test_bulk_update_status(self, async_test_db, async_test_user): + """Test bulk activating users returns correct count.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + count = await user_service.bulk_update_status( + session, + user_ids=[async_test_user.id], + is_active=True, + ) + assert count >= 1 + + async with AsyncTestingSessionLocal() as session: + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) + user = result.scalar_one_or_none() + assert user is not None + assert user.is_active is True + + +class TestBulkSoftDelete: + """Tests for UserService.bulk_soft_delete method.""" + + @pytest.mark.asyncio + async def test_bulk_soft_delete(self, async_test_db, async_test_user): + """Test bulk soft-deleting users returns correct count.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + count = await user_service.bulk_soft_delete( + session, + user_ids=[async_test_user.id], + ) + assert count >= 1 + + async with AsyncTestingSessionLocal() as session: + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) + user = result.scalar_one_or_none() + assert user is not None + assert user.deleted_at is not None + + +class TestGetStats: + """Tests for UserService.get_stats method.""" + + @pytest.mark.asyncio + async def test_get_stats(self, async_test_db, async_test_user): + """Test get_stats returns dict with expected keys and correct counts.""" + _test_engine, AsyncTestingSessionLocal = async_test_db + async with AsyncTestingSessionLocal() as session: + stats = await user_service.get_stats(session) + assert "total_users" in stats + assert "active_count" in stats + assert "inactive_count" in stats + assert "all_users" in stats + assert stats["total_users"] >= 1 + assert stats["active_count"] >= 1 + assert isinstance(stats["all_users"], list) + assert len(stats["all_users"]) >= 1