forked from cardosofelipe/pragma-stack
refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError, IntegrityConstraintError, InvalidInputError) replacing raw ValueError - eliminate all direct repository imports and raw SQL from route layer - add UserService, SessionService, OrganizationService to service layer - add get_stats/get_org_distribution service methods replacing admin inline SQL - fix timing side-channel in authenticate_user via dummy bcrypt check - replace SHA-256 client secret fallback with explicit InvalidClientError - replace assert with InvalidGrantError in authorization code exchange - replace N+1 token revocation loops with bulk UPDATE statements - rename oauth account token fields (drop misleading 'encrypted' suffix) - add Alembic migration 0003 for token field column rename - add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
@@ -0,0 +1,28 @@
|
|||||||
|
"""rename oauth account token fields drop encrypted suffix
|
||||||
|
|
||||||
|
Revision ID: 0003
|
||||||
|
Revises: 0002
|
||||||
|
Create Date: 2026-02-27 01:03:18.869178
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0003"
|
||||||
|
down_revision: str | None = "0002"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.alter_column("oauth_accounts", "access_token_encrypted", new_column_name="access_token")
|
||||||
|
op.alter_column("oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column("oauth_accounts", "access_token", new_column_name="access_token_encrypted")
|
||||||
|
op.alter_column("oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted")
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
from fastapi import Depends, Header, HTTPException, status
|
from fastapi import Depends, Header, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from fastapi.security.utils import get_authorization_scheme_param
|
from fastapi.security.utils import get_authorization_scheme_param
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
|
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.user import user_repo
|
||||||
|
|
||||||
# OAuth2 configuration
|
# OAuth2 configuration
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
@@ -32,9 +32,8 @@ async def get_current_user(
|
|||||||
# Decode token and get user ID
|
# Decode token and get user ID
|
||||||
token_data = get_token_data(token)
|
token_data = get_token_data(token)
|
||||||
|
|
||||||
# Get user from database
|
# Get user from database via repository
|
||||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -144,8 +143,7 @@ async def get_optional_current_user(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
token_data = get_token_data(token)
|
token_data = get_token_data(token)
|
||||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
return None
|
return None
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.crud.organization import organization as organization_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole
|
from app.models.user_organization import OrganizationRole
|
||||||
|
from app.services.organization_service import organization_service
|
||||||
|
|
||||||
|
|
||||||
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
|
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||||
@@ -81,7 +81,7 @@ class OrganizationPermission:
|
|||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
# Get user's role in organization
|
# Get user's role in organization
|
||||||
user_role = await organization_crud.get_user_role_in_org(
|
user_role = await organization_service.get_user_role_in_org(
|
||||||
db, user_id=current_user.id, organization_id=organization_id
|
db, user_id=current_user.id, organization_id=organization_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ async def require_org_membership(
|
|||||||
if current_user.is_superuser:
|
if current_user.is_superuser:
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
user_role = await organization_crud.get_user_role_in_org(
|
user_role = await organization_service.get_user_role_in_org(
|
||||||
db, user_id=current_user.id, organization_id=organization_id
|
db, user_id=current_user.id, organization_id=organization_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
41
backend/app/api/dependencies/services.py
Normal file
41
backend/app/api/dependencies/services.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# app/api/dependencies/services.py
|
||||||
|
"""FastAPI dependency functions for service singletons."""
|
||||||
|
|
||||||
|
from app.services import oauth_provider_service
|
||||||
|
from app.services.auth_service import AuthService
|
||||||
|
from app.services.oauth_service import OAuthService
|
||||||
|
from app.services.organization_service import OrganizationService, organization_service
|
||||||
|
from app.services.session_service import SessionService, session_service
|
||||||
|
from app.services.user_service import UserService, user_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_service() -> AuthService:
|
||||||
|
"""Return the AuthService singleton for dependency injection."""
|
||||||
|
from app.services.auth_service import AuthService as _AuthService
|
||||||
|
|
||||||
|
return _AuthService()
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_service() -> UserService:
|
||||||
|
"""Return the UserService singleton for dependency injection."""
|
||||||
|
return user_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_organization_service() -> OrganizationService:
|
||||||
|
"""Return the OrganizationService singleton for dependency injection."""
|
||||||
|
return organization_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_service() -> SessionService:
|
||||||
|
"""Return the SessionService singleton for dependency injection."""
|
||||||
|
return session_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_oauth_service() -> OAuthService:
|
||||||
|
"""Return OAuthService for dependency injection."""
|
||||||
|
return OAuthService()
|
||||||
|
|
||||||
|
|
||||||
|
def get_oauth_provider_service():
|
||||||
|
"""Return the oauth_provider_service module for dependency injection."""
|
||||||
|
return oauth_provider_service
|
||||||
@@ -14,7 +14,6 @@ from uuid import UUID
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.dependencies.permissions import require_superuser
|
from app.api.dependencies.permissions import require_superuser
|
||||||
@@ -25,12 +24,12 @@ from app.core.exceptions import (
|
|||||||
ErrorCode,
|
ErrorCode,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
)
|
)
|
||||||
from app.crud.organization import organization as organization_crud
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
from app.crud.session import session as session_crud
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
from app.models.organization import Organization
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
from app.models.user_organization import OrganizationRole
|
||||||
|
from app.services.organization_service import organization_service
|
||||||
|
from app.services.session_service import session_service
|
||||||
|
from app.services.user_service import user_service
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
@@ -178,38 +177,29 @@ async def admin_get_stats(
|
|||||||
"""Get admin dashboard statistics with real data from database."""
|
"""Get admin dashboard statistics with real data from database."""
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
# Check if we have any data
|
stats = await user_service.get_stats(db)
|
||||||
total_users_query = select(func.count()).select_from(User)
|
total_users = stats["total_users"]
|
||||||
total_users = (await db.execute(total_users_query)).scalar() or 0
|
active_count = stats["active_count"]
|
||||||
|
inactive_count = stats["inactive_count"]
|
||||||
|
all_users = stats["all_users"]
|
||||||
|
|
||||||
# If database is essentially empty (only admin user), return demo data
|
# If database is essentially empty (only admin user), return demo data
|
||||||
if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover
|
if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover
|
||||||
logger.info("Returning demo stats data (empty database in demo mode)")
|
logger.info("Returning demo stats data (empty database in demo mode)")
|
||||||
return _generate_demo_stats()
|
return _generate_demo_stats()
|
||||||
|
|
||||||
# 1. User Growth (Last 30 days) - Improved calculation
|
# 1. User Growth (Last 30 days)
|
||||||
datetime.now(UTC) - timedelta(days=30)
|
|
||||||
|
|
||||||
# Get all users with their creation dates
|
|
||||||
all_users_query = select(User).order_by(User.created_at)
|
|
||||||
result = await db.execute(all_users_query)
|
|
||||||
all_users = result.scalars().all()
|
|
||||||
|
|
||||||
# Build cumulative counts per day
|
|
||||||
user_growth = []
|
user_growth = []
|
||||||
for i in range(29, -1, -1):
|
for i in range(29, -1, -1):
|
||||||
date = datetime.now(UTC) - timedelta(days=i)
|
date = datetime.now(UTC) - timedelta(days=i)
|
||||||
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
|
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
|
||||||
date_end = date_start + timedelta(days=1)
|
date_end = date_start + timedelta(days=1)
|
||||||
|
|
||||||
# Count all users created before end of this day
|
|
||||||
# Make comparison timezone-aware
|
|
||||||
total_users_on_date = sum(
|
total_users_on_date = sum(
|
||||||
1
|
1
|
||||||
for u in all_users
|
for u in all_users
|
||||||
if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end
|
if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end
|
||||||
)
|
)
|
||||||
# Count active users created before end of this day
|
|
||||||
active_users_on_date = sum(
|
active_users_on_date = sum(
|
||||||
1
|
1
|
||||||
for u in all_users
|
for u in all_users
|
||||||
@@ -227,27 +217,16 @@ async def admin_get_stats(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 2. Organization Distribution - Top 6 organizations by member count
|
# 2. Organization Distribution - Top 6 organizations by member count
|
||||||
org_query = (
|
org_rows = await organization_service.get_org_distribution(db, limit=6)
|
||||||
select(Organization.name, func.count(UserOrganization.user_id).label("count"))
|
org_dist = [OrgDistributionData(name=r["name"], value=r["value"]) for r in org_rows]
|
||||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
|
||||||
.group_by(Organization.name)
|
|
||||||
.order_by(func.count(UserOrganization.user_id).desc())
|
|
||||||
.limit(6)
|
|
||||||
)
|
|
||||||
result = await db.execute(org_query)
|
|
||||||
org_dist = [
|
|
||||||
OrgDistributionData(name=row.name, value=row.count) for row in result.all()
|
|
||||||
]
|
|
||||||
|
|
||||||
# 3. User Registration Activity (Last 14 days) - NEW
|
# 3. User Registration Activity (Last 14 days)
|
||||||
registration_activity = []
|
registration_activity = []
|
||||||
for i in range(13, -1, -1):
|
for i in range(13, -1, -1):
|
||||||
date = datetime.now(UTC) - timedelta(days=i)
|
date = datetime.now(UTC) - timedelta(days=i)
|
||||||
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
|
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
|
||||||
date_end = date_start + timedelta(days=1)
|
date_end = date_start + timedelta(days=1)
|
||||||
|
|
||||||
# Count users created on this specific day
|
|
||||||
# Make comparison timezone-aware
|
|
||||||
day_registrations = sum(
|
day_registrations = sum(
|
||||||
1
|
1
|
||||||
for u in all_users
|
for u in all_users
|
||||||
@@ -263,14 +242,6 @@ async def admin_get_stats(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 4. User Status - Active vs Inactive
|
# 4. User Status - Active vs Inactive
|
||||||
active_query = select(func.count()).select_from(User).where(User.is_active)
|
|
||||||
inactive_query = (
|
|
||||||
select(func.count()).select_from(User).where(User.is_active.is_(False))
|
|
||||||
)
|
|
||||||
|
|
||||||
active_count = (await db.execute(active_query)).scalar() or 0
|
|
||||||
inactive_count = (await db.execute(inactive_query)).scalar() or 0
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User status counts - Active: {active_count}, Inactive: {inactive_count}"
|
f"User status counts - Active: {active_count}, Inactive: {inactive_count}"
|
||||||
)
|
)
|
||||||
@@ -321,7 +292,7 @@ async def admin_list_users(
|
|||||||
filters["is_superuser"] = is_superuser
|
filters["is_superuser"] = is_superuser
|
||||||
|
|
||||||
# Get users with search
|
# Get users with search
|
||||||
users, total = await user_crud.get_multi_with_total(
|
users, total = await user_service.list_users(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -364,12 +335,12 @@ async def admin_create_user(
|
|||||||
Allows setting is_superuser and other fields.
|
Allows setting is_superuser and other fields.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.create(db, obj_in=user_in)
|
user = await user_service.create_user(db, user_in)
|
||||||
logger.info(f"Admin {admin.email} created user {user.email}")
|
logger.info(f"Admin {admin.email} created user {user.email}")
|
||||||
return user
|
return user
|
||||||
except ValueError as e:
|
except DuplicateEntryError as e:
|
||||||
logger.warning(f"Failed to create user: {e!s}")
|
logger.warning(f"Failed to create user: {e!s}")
|
||||||
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
|
raise DuplicateError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
|
logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -388,11 +359,7 @@ async def admin_get_user(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Get detailed information about a specific user."""
|
"""Get detailed information about a specific user."""
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@@ -411,18 +378,11 @@ async def admin_update_user(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Update user information with admin privileges."""
|
"""Update user information with admin privileges."""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
updated_user = await user_service.update_user(db, user=user, obj_in=user_in)
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
|
|
||||||
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
|
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
|
||||||
return updated_user
|
return updated_user
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
|
logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -442,11 +402,7 @@ async def admin_delete_user(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Soft delete a user (sets deleted_at timestamp)."""
|
"""Soft delete a user (sets deleted_at timestamp)."""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prevent deleting yourself
|
# Prevent deleting yourself
|
||||||
if user.id == admin.id:
|
if user.id == admin.id:
|
||||||
@@ -456,15 +412,13 @@ async def admin_delete_user(
|
|||||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
await user_crud.soft_delete(db, id=user_id)
|
await user_service.soft_delete_user(db, str(user_id))
|
||||||
logger.info(f"Admin {admin.email} deleted user {user.email}")
|
logger.info(f"Admin {admin.email} deleted user {user.email}")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"User {user.email} has been deleted"
|
success=True, message=f"User {user.email} has been deleted"
|
||||||
)
|
)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
|
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -484,21 +438,14 @@ async def admin_activate_user(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Activate a user account."""
|
"""Activate a user account."""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
await user_service.update_user(db, user=user, obj_in={"is_active": True})
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
|
||||||
logger.info(f"Admin {admin.email} activated user {user.email}")
|
logger.info(f"Admin {admin.email} activated user {user.email}")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"User {user.email} has been activated"
|
success=True, message=f"User {user.email} has been activated"
|
||||||
)
|
)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
|
logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -518,11 +465,7 @@ async def admin_deactivate_user(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Deactivate a user account."""
|
"""Deactivate a user account."""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prevent deactivating yourself
|
# Prevent deactivating yourself
|
||||||
if user.id == admin.id:
|
if user.id == admin.id:
|
||||||
@@ -532,15 +475,13 @@ async def admin_deactivate_user(
|
|||||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
await user_service.update_user(db, user=user, obj_in={"is_active": False})
|
||||||
logger.info(f"Admin {admin.email} deactivated user {user.email}")
|
logger.info(f"Admin {admin.email} deactivated user {user.email}")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"User {user.email} has been deactivated"
|
success=True, message=f"User {user.email} has been deactivated"
|
||||||
)
|
)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
|
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -567,16 +508,16 @@ async def admin_bulk_user_action(
|
|||||||
try:
|
try:
|
||||||
# Use efficient bulk operations instead of loop
|
# Use efficient bulk operations instead of loop
|
||||||
if bulk_action.action == BulkAction.ACTIVATE:
|
if bulk_action.action == BulkAction.ACTIVATE:
|
||||||
affected_count = await user_crud.bulk_update_status(
|
affected_count = await user_service.bulk_update_status(
|
||||||
db, user_ids=bulk_action.user_ids, is_active=True
|
db, user_ids=bulk_action.user_ids, is_active=True
|
||||||
)
|
)
|
||||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||||
affected_count = await user_crud.bulk_update_status(
|
affected_count = await user_service.bulk_update_status(
|
||||||
db, user_ids=bulk_action.user_ids, is_active=False
|
db, user_ids=bulk_action.user_ids, is_active=False
|
||||||
)
|
)
|
||||||
elif bulk_action.action == BulkAction.DELETE:
|
elif bulk_action.action == BulkAction.DELETE:
|
||||||
# bulk_soft_delete automatically excludes the admin user
|
# bulk_soft_delete automatically excludes the admin user
|
||||||
affected_count = await user_crud.bulk_soft_delete(
|
affected_count = await user_service.bulk_soft_delete(
|
||||||
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
|
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
|
||||||
)
|
)
|
||||||
else: # pragma: no cover
|
else: # pragma: no cover
|
||||||
@@ -624,7 +565,7 @@ async def admin_list_organizations(
|
|||||||
"""List all organizations with filtering and search."""
|
"""List all organizations with filtering and search."""
|
||||||
try:
|
try:
|
||||||
# Use optimized method that gets member counts in single query (no N+1)
|
# Use optimized method that gets member counts in single query (no N+1)
|
||||||
orgs_with_data, total = await organization_crud.get_multi_with_member_counts(
|
orgs_with_data, total = await organization_service.get_multi_with_member_counts(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -680,7 +621,7 @@ async def admin_create_organization(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Create a new organization."""
|
"""Create a new organization."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.create(db, obj_in=org_in)
|
org = await organization_service.create_organization(db, obj_in=org_in)
|
||||||
logger.info(f"Admin {admin.email} created organization {org.name}")
|
logger.info(f"Admin {admin.email} created organization {org.name}")
|
||||||
|
|
||||||
# Add member count
|
# Add member count
|
||||||
@@ -697,9 +638,9 @@ async def admin_create_organization(
|
|||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
except ValueError as e:
|
except DuplicateEntryError as e:
|
||||||
logger.warning(f"Failed to create organization: {e!s}")
|
logger.warning(f"Failed to create organization: {e!s}")
|
||||||
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
raise DuplicateError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
|
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -718,12 +659,7 @@ async def admin_get_organization(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Get detailed information about a specific organization."""
|
"""Get detailed information about a specific organization."""
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
org_dict = {
|
org_dict = {
|
||||||
"id": org.id,
|
"id": org.id,
|
||||||
"name": org.name,
|
"name": org.name,
|
||||||
@@ -733,7 +669,7 @@ async def admin_get_organization(
|
|||||||
"settings": org.settings,
|
"settings": org.settings,
|
||||||
"created_at": org.created_at,
|
"created_at": org.created_at,
|
||||||
"updated_at": org.updated_at,
|
"updated_at": org.updated_at,
|
||||||
"member_count": await organization_crud.get_member_count(
|
"member_count": await organization_service.get_member_count(
|
||||||
db, organization_id=org.id
|
db, organization_id=org.id
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@@ -755,14 +691,10 @@ async def admin_update_organization(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Update organization information."""
|
"""Update organization information."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
updated_org = await organization_service.update_organization(
|
||||||
raise NotFoundError(
|
db, org=org, obj_in=org_in
|
||||||
message=f"Organization {org_id} not found",
|
)
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
|
||||||
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
|
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
|
||||||
|
|
||||||
org_dict = {
|
org_dict = {
|
||||||
@@ -774,14 +706,12 @@ async def admin_update_organization(
|
|||||||
"settings": updated_org.settings,
|
"settings": updated_org.settings,
|
||||||
"created_at": updated_org.created_at,
|
"created_at": updated_org.created_at,
|
||||||
"updated_at": updated_org.updated_at,
|
"updated_at": updated_org.updated_at,
|
||||||
"member_count": await organization_crud.get_member_count(
|
"member_count": await organization_service.get_member_count(
|
||||||
db, organization_id=updated_org.id
|
db, organization_id=updated_org.id
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
|
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -801,22 +731,14 @@ async def admin_delete_organization(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Delete an organization and all its relationships."""
|
"""Delete an organization and all its relationships."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
await organization_service.remove_organization(db, str(org_id))
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Organization {org_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
await organization_crud.remove(db, id=org_id)
|
|
||||||
logger.info(f"Admin {admin.email} deleted organization {org.name}")
|
logger.info(f"Admin {admin.email} deleted organization {org.name}")
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"Organization {org.name} has been deleted"
|
success=True, message=f"Organization {org.name} has been deleted"
|
||||||
)
|
)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
|
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -838,14 +760,8 @@ async def admin_list_organization_members(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""List all members of an organization."""
|
"""List all members of an organization."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
await organization_service.get_organization(db, str(org_id)) # validates exists
|
||||||
if not org:
|
members, total = await organization_service.get_organization_members(
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Organization {org_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
members, total = await organization_crud.get_organization_members(
|
|
||||||
db,
|
db,
|
||||||
organization_id=org_id,
|
organization_id=org_id,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
@@ -898,21 +814,10 @@ async def admin_add_organization_member(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Add a user to an organization."""
|
"""Add a user to an organization."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
user = await user_service.get_user(db, str(request.user_id))
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Organization {org_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
user = await user_crud.get(db, id=request.user_id)
|
await organization_service.add_member(
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {request.user_id} not found",
|
|
||||||
error_code=ErrorCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
await organization_crud.add_user(
|
|
||||||
db, organization_id=org_id, user_id=request.user_id, role=request.role
|
db, organization_id=org_id, user_id=request.user_id, role=request.role
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -925,14 +830,11 @@ async def admin_add_organization_member(
|
|||||||
success=True, message=f"User {user.email} added to organization {org.name}"
|
success=True, message=f"User {user.email} added to organization {org.name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except ValueError as e:
|
except DuplicateEntryError as e:
|
||||||
logger.warning(f"Failed to add user to organization: {e!s}")
|
logger.warning(f"Failed to add user to organization: {e!s}")
|
||||||
# Use DuplicateError for "already exists" scenarios
|
|
||||||
raise DuplicateError(
|
raise DuplicateError(
|
||||||
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
|
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
|
||||||
)
|
)
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error adding member to organization (admin): {e!s}", exc_info=True
|
f"Error adding member to organization (admin): {e!s}", exc_info=True
|
||||||
@@ -955,20 +857,10 @@ async def admin_remove_organization_member(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Remove a user from an organization."""
|
"""Remove a user from an organization."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
user = await user_service.get_user(db, str(user_id))
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Organization {org_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
user = await user_crud.get(db, id=user_id)
|
success = await organization_service.remove_member(
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
success = await organization_crud.remove_user(
|
|
||||||
db, organization_id=org_id, user_id=user_id
|
db, organization_id=org_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1022,7 +914,7 @@ async def admin_list_sessions(
|
|||||||
"""List all sessions across all users with filtering and pagination."""
|
"""List all sessions across all users with filtering and pagination."""
|
||||||
try:
|
try:
|
||||||
# Get sessions with user info (eager loaded to prevent N+1)
|
# Get sessions with user info (eager loaded to prevent N+1)
|
||||||
sessions, total = await session_crud.get_all_sessions(
|
sessions, total = await session_service.get_all_sessions(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
|
|||||||
@@ -15,16 +15,14 @@ from app.core.auth import (
|
|||||||
TokenExpiredError,
|
TokenExpiredError,
|
||||||
TokenInvalidError,
|
TokenInvalidError,
|
||||||
decode_token,
|
decode_token,
|
||||||
get_password_hash,
|
|
||||||
)
|
)
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
AuthenticationError as AuthError,
|
AuthenticationError as AuthError,
|
||||||
DatabaseError,
|
DatabaseError,
|
||||||
|
DuplicateError,
|
||||||
ErrorCode,
|
ErrorCode,
|
||||||
)
|
)
|
||||||
from app.crud.session import session as session_crud
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.schemas.sessions import LogoutRequest, SessionCreate
|
from app.schemas.sessions import LogoutRequest, SessionCreate
|
||||||
@@ -39,6 +37,8 @@ from app.schemas.users import (
|
|||||||
)
|
)
|
||||||
from app.services.auth_service import AuthenticationError, AuthService
|
from app.services.auth_service import AuthenticationError, AuthService
|
||||||
from app.services.email_service import email_service
|
from app.services.email_service import email_service
|
||||||
|
from app.services.session_service import session_service
|
||||||
|
from app.services.user_service import user_service
|
||||||
from app.utils.device import extract_device_info
|
from app.utils.device import extract_device_info
|
||||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ async def _create_login_session(
|
|||||||
location_country=device_info.location_country,
|
location_country=device_info.location_country,
|
||||||
)
|
)
|
||||||
|
|
||||||
await session_crud.create_session(db, obj_in=session_data)
|
await session_service.create_session(db, obj_in=session_data)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
|
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
|
||||||
@@ -123,8 +123,14 @@ async def register_user(
|
|||||||
try:
|
try:
|
||||||
user = await AuthService.create_user(db, user_data)
|
user = await AuthService.create_user(db, user_data)
|
||||||
return user
|
return user
|
||||||
except AuthenticationError as e:
|
except DuplicateError:
|
||||||
# SECURITY: Don't reveal if email exists - generic error message
|
# SECURITY: Don't reveal if email exists - generic error message
|
||||||
|
logger.warning(f"Registration failed: duplicate email {user_data.email}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Registration failed. Please check your information and try again.",
|
||||||
|
)
|
||||||
|
except AuthError as e:
|
||||||
logger.warning(f"Registration failed: {e!s}")
|
logger.warning(f"Registration failed: {e!s}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
@@ -259,7 +265,7 @@ async def refresh_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if session exists and is active
|
# Check if session exists and is active
|
||||||
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -279,7 +285,7 @@ async def refresh_token(
|
|||||||
|
|
||||||
# Update session with new refresh token JTI and expiration
|
# Update session with new refresh token JTI and expiration
|
||||||
try:
|
try:
|
||||||
await session_crud.update_refresh_token(
|
await session_service.update_refresh_token(
|
||||||
db,
|
db,
|
||||||
session=session,
|
session=session,
|
||||||
new_jti=new_refresh_payload.jti,
|
new_jti=new_refresh_payload.jti,
|
||||||
@@ -347,7 +353,7 @@ async def request_password_reset(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Look up user by email
|
# Look up user by email
|
||||||
user = await user_crud.get_by_email(db, email=reset_request.email)
|
user = await user_service.get_by_email(db, email=reset_request.email)
|
||||||
|
|
||||||
# Only send email if user exists and is active
|
# Only send email if user exists and is active
|
||||||
if user and user.is_active:
|
if user and user.is_active:
|
||||||
@@ -412,31 +418,25 @@ async def confirm_password_reset(
|
|||||||
detail="Invalid or expired password reset token",
|
detail="Invalid or expired password reset token",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Look up user
|
# Reset password via service (validates user exists and is active)
|
||||||
user = await user_crud.get_by_email(db, email=email)
|
try:
|
||||||
|
user = await AuthService.reset_password(
|
||||||
if not user:
|
db, email=email, new_password=reset_confirm.new_password
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
|
||||||
)
|
)
|
||||||
|
except AuthenticationError as e:
|
||||||
if not user.is_active:
|
err_msg = str(e)
|
||||||
|
if "inactive" in err_msg.lower():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
|
||||||
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_404_NOT_FOUND, detail=err_msg
|
||||||
detail="User account is inactive",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update password
|
|
||||||
user.password_hash = get_password_hash(reset_confirm.new_password)
|
|
||||||
db.add(user)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
# SECURITY: Invalidate all existing sessions after password reset
|
# SECURITY: Invalidate all existing sessions after password reset
|
||||||
# This prevents stolen sessions from being used after password change
|
# This prevents stolen sessions from being used after password change
|
||||||
from app.crud.session import session as session_crud
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
deactivated_count = await session_service.deactivate_all_user_sessions(
|
||||||
db, user_id=str(user.id)
|
db, user_id=str(user.id)
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -511,7 +511,7 @@ async def logout(
|
|||||||
return MessageResponse(success=True, message="Logged out successfully")
|
return MessageResponse(success=True, message="Logged out successfully")
|
||||||
|
|
||||||
# Find the session by JTI
|
# Find the session by JTI
|
||||||
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
|
||||||
|
|
||||||
if session:
|
if session:
|
||||||
# Verify session belongs to current user (security check)
|
# Verify session belongs to current user (security check)
|
||||||
@@ -526,7 +526,7 @@ async def logout(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Deactivate the session
|
# Deactivate the session
|
||||||
await session_crud.deactivate(db, session_id=str(session.id))
|
await session_service.deactivate(db, session_id=str(session.id))
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} logged out from {session.device_name} "
|
f"User {current_user.id} logged out from {session.device_name} "
|
||||||
@@ -584,7 +584,7 @@ async def logout_all(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Deactivate all sessions for this user
|
# Deactivate all sessions for this user
|
||||||
count = await session_crud.deactivate_all_user_sessions(
|
count = await session_service.deactivate_all_user_sessions(
|
||||||
db, user_id=str(current_user.id)
|
db, user_id=str(current_user.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,7 @@ from app.core.auth import decode_token
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import AuthenticationError as AuthError
|
from app.core.exceptions import AuthenticationError as AuthError
|
||||||
from app.crud import oauth_account
|
from app.services.session_service import session_service
|
||||||
from app.crud.session import session as session_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.oauth import (
|
from app.schemas.oauth import (
|
||||||
OAuthAccountsListResponse,
|
OAuthAccountsListResponse,
|
||||||
@@ -82,7 +81,7 @@ async def _create_oauth_login_session(
|
|||||||
location_country=device_info.location_country,
|
location_country=device_info.location_country,
|
||||||
)
|
)
|
||||||
|
|
||||||
await session_crud.create_session(db, obj_in=session_data)
|
await session_service.create_session(db, obj_in=session_data)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"OAuth login successful: {user.email} via {provider} "
|
f"OAuth login successful: {user.email} via {provider} "
|
||||||
@@ -289,7 +288,7 @@ async def list_accounts(
|
|||||||
Returns:
|
Returns:
|
||||||
List of linked OAuth accounts
|
List of linked OAuth accounts
|
||||||
"""
|
"""
|
||||||
accounts = await oauth_account.get_user_accounts(db, user_id=current_user.id)
|
accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
|
||||||
return OAuthAccountsListResponse(accounts=accounts)
|
return OAuthAccountsListResponse(accounts=accounts)
|
||||||
|
|
||||||
|
|
||||||
@@ -397,7 +396,7 @@ async def start_link(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if user already has this provider linked
|
# Check if user already has this provider linked
|
||||||
existing = await oauth_account.get_user_account_by_provider(
|
existing = await OAuthService.get_user_account_by_provider(
|
||||||
db, user_id=current_user.id, provider=provider
|
db, user_id=current_user.id, provider=provider
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ from app.api.dependencies.auth import (
|
|||||||
)
|
)
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.crud import oauth_client as oauth_client_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.oauth import (
|
from app.schemas.oauth import (
|
||||||
OAuthClientCreate,
|
OAuthClientCreate,
|
||||||
@@ -712,7 +711,7 @@ async def register_client(
|
|||||||
client_type=client_type,
|
client_type=client_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
client, secret = await oauth_client_crud.create_client(db, obj_in=client_data)
|
client, secret = await provider_service.register_client(db, client_data)
|
||||||
|
|
||||||
# Update MCP server URL if provided
|
# Update MCP server URL if provided
|
||||||
if mcp_server_url:
|
if mcp_server_url:
|
||||||
@@ -750,7 +749,7 @@ async def list_clients(
|
|||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
) -> list[OAuthClientResponse]:
|
) -> list[OAuthClientResponse]:
|
||||||
"""List all OAuth clients."""
|
"""List all OAuth clients."""
|
||||||
clients = await oauth_client_crud.get_all_clients(db)
|
clients = await provider_service.list_clients(db)
|
||||||
return [OAuthClientResponse.model_validate(c) for c in clients]
|
return [OAuthClientResponse.model_validate(c) for c in clients]
|
||||||
|
|
||||||
|
|
||||||
@@ -776,7 +775,7 @@ async def delete_client(
|
|||||||
detail="Client not found",
|
detail="Client not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
await oauth_client_crud.delete_client(db, client_id=client_id)
|
await provider_service.delete_client_by_id(db, client_id=client_id)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -797,30 +796,7 @@ async def list_my_consents(
|
|||||||
current_user: User = Depends(get_current_active_user),
|
current_user: User = Depends(get_current_active_user),
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""List applications the user has authorized."""
|
"""List applications the user has authorized."""
|
||||||
from sqlalchemy import select
|
return await provider_service.list_user_consents(db, user_id=current_user.id)
|
||||||
|
|
||||||
from app.models.oauth_client import OAuthClient
|
|
||||||
from app.models.oauth_provider_token import OAuthConsent
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(OAuthConsent, OAuthClient)
|
|
||||||
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
|
||||||
.where(OAuthConsent.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
rows = result.all()
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"client_id": consent.client_id,
|
|
||||||
"client_name": client.client_name,
|
|
||||||
"client_description": client.client_description,
|
|
||||||
"granted_scopes": consent.granted_scopes.split()
|
|
||||||
if consent.granted_scopes
|
|
||||||
else [],
|
|
||||||
"granted_at": consent.created_at.isoformat(),
|
|
||||||
}
|
|
||||||
for consent, client in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
|
|||||||
@@ -15,9 +15,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import ErrorCode, NotFoundError
|
|
||||||
from app.crud.organization import organization as organization_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.services.organization_service import organization_service
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
PaginationParams,
|
PaginationParams,
|
||||||
@@ -54,7 +53,7 @@ async def get_my_organizations(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get all org data in single query with JOIN and subquery
|
# Get all org data in single query with JOIN and subquery
|
||||||
orgs_data = await organization_crud.get_user_organizations_with_details(
|
orgs_data = await organization_service.get_user_organizations_with_details(
|
||||||
db, user_id=current_user.id, is_active=is_active
|
db, user_id=current_user.id, is_active=is_active
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -100,13 +99,7 @@ async def get_organization(
|
|||||||
User must be a member of the organization.
|
User must be a member of the organization.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=organization_id)
|
org = await organization_service.get_organization(db, str(organization_id))
|
||||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
|
||||||
raise NotFoundError(
|
|
||||||
detail=f"Organization {organization_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
org_dict = {
|
org_dict = {
|
||||||
"id": org.id,
|
"id": org.id,
|
||||||
"name": org.name,
|
"name": org.name,
|
||||||
@@ -116,14 +109,12 @@ async def get_organization(
|
|||||||
"settings": org.settings,
|
"settings": org.settings,
|
||||||
"created_at": org.created_at,
|
"created_at": org.created_at,
|
||||||
"updated_at": org.updated_at,
|
"updated_at": org.updated_at,
|
||||||
"member_count": await organization_crud.get_member_count(
|
"member_count": await organization_service.get_member_count(
|
||||||
db, organization_id=org.id
|
db, organization_id=org.id
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
except NotFoundError: # pragma: no cover - See above
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting organization: {e!s}", exc_info=True)
|
logger.error(f"Error getting organization: {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -149,7 +140,7 @@ async def get_organization_members(
|
|||||||
User must be a member of the organization to view members.
|
User must be a member of the organization to view members.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
members, total = await organization_crud.get_organization_members(
|
members, total = await organization_service.get_organization_members(
|
||||||
db,
|
db,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
@@ -192,14 +183,10 @@ async def update_organization(
|
|||||||
Requires owner or admin role in the organization.
|
Requires owner or admin role in the organization.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=organization_id)
|
org = await organization_service.get_organization(db, str(organization_id))
|
||||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
updated_org = await organization_service.update_organization(
|
||||||
raise NotFoundError(
|
db, org=org, obj_in=org_in
|
||||||
detail=f"Organization {organization_id} not found",
|
)
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.email} updated organization {updated_org.name}"
|
f"User {current_user.email} updated organization {updated_org.name}"
|
||||||
)
|
)
|
||||||
@@ -213,14 +200,12 @@ async def update_organization(
|
|||||||
"settings": updated_org.settings,
|
"settings": updated_org.settings,
|
||||||
"created_at": updated_org.created_at,
|
"created_at": updated_org.created_at,
|
||||||
"updated_at": updated_org.updated_at,
|
"updated_at": updated_org.updated_at,
|
||||||
"member_count": await organization_crud.get_member_count(
|
"member_count": await organization_service.get_member_count(
|
||||||
db, organization_id=updated_org.id
|
db, organization_id=updated_org.id
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
except NotFoundError: # pragma: no cover - See above
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating organization: {e!s}", exc_info=True)
|
logger.error(f"Error updating organization: {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ from app.api.dependencies.auth import get_current_user
|
|||||||
from app.core.auth import decode_token
|
from app.core.auth import decode_token
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||||
from app.crud.session import session as session_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.services.session_service import session_service
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.schemas.sessions import SessionListResponse, SessionResponse
|
from app.schemas.sessions import SessionListResponse, SessionResponse
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ async def list_my_sessions(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get all active sessions for user
|
# Get all active sessions for user
|
||||||
sessions = await session_crud.get_user_sessions(
|
sessions = await session_service.get_user_sessions(
|
||||||
db, user_id=str(current_user.id), active_only=True
|
db, user_id=str(current_user.id), active_only=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -150,7 +150,7 @@ async def revoke_session(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get the session
|
# Get the session
|
||||||
session = await session_crud.get(db, id=str(session_id))
|
session = await session_service.get_session(db, str(session_id))
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
@@ -170,7 +170,7 @@ async def revoke_session(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Deactivate the session
|
# Deactivate the session
|
||||||
await session_crud.deactivate(db, session_id=str(session_id))
|
await session_service.deactivate(db, session_id=str(session_id))
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} revoked session {session_id} "
|
f"User {current_user.id} revoked session {session_id} "
|
||||||
@@ -224,7 +224,7 @@ async def cleanup_expired_sessions(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Use optimized bulk DELETE instead of N individual deletes
|
# Use optimized bulk DELETE instead of N individual deletes
|
||||||
deleted_count = await session_crud.cleanup_expired_for_user(
|
deleted_count = await session_service.cleanup_expired_for_user(
|
||||||
db, user_id=str(current_user.id)
|
db, user_id=str(current_user.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from app.api.dependencies.auth import get_current_superuser, get_current_user
|
from app.api.dependencies.auth import get_current_superuser, get_current_user
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
@@ -25,6 +24,7 @@ from app.schemas.common import (
|
|||||||
)
|
)
|
||||||
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
|
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
|
||||||
from app.services.auth_service import AuthenticationError, AuthService
|
from app.services.auth_service import AuthenticationError, AuthService
|
||||||
|
from app.services.user_service import user_service
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ async def list_users(
|
|||||||
filters["is_superuser"] = is_superuser
|
filters["is_superuser"] = is_superuser
|
||||||
|
|
||||||
# Get paginated users with total count
|
# Get paginated users with total count
|
||||||
users, total = await user_crud.get_multi_with_total(
|
users, total = await user_service.list_users(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -107,7 +107,7 @@ async def list_users(
|
|||||||
""",
|
""",
|
||||||
operation_id="get_current_user_profile",
|
operation_id="get_current_user_profile",
|
||||||
)
|
)
|
||||||
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
|
async def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
|
||||||
"""Get current user's profile."""
|
"""Get current user's profile."""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
@@ -138,8 +138,8 @@ async def update_current_user(
|
|||||||
Users cannot elevate their own permissions (protected by UserUpdate schema validator).
|
Users cannot elevate their own permissions (protected by UserUpdate schema validator).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
updated_user = await user_crud.update(
|
updated_user = await user_service.update_user(
|
||||||
db, db_obj=current_user, obj_in=user_update
|
db, user=current_user, obj_in=user_update
|
||||||
)
|
)
|
||||||
logger.info(f"User {current_user.id} updated their profile")
|
logger.info(f"User {current_user.id} updated their profile")
|
||||||
return updated_user
|
return updated_user
|
||||||
@@ -190,13 +190,7 @@ async def get_user_by_id(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user = await user_crud.get(db, id=str(user_id))
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User with id {user_id} not found",
|
|
||||||
error_code=ErrorCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@@ -241,15 +235,10 @@ async def update_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user = await user_crud.get(db, id=str(user_id))
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User with id {user_id} not found",
|
|
||||||
error_code=ErrorCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
|
updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
|
||||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
logger.info(f"User {user_id} updated by {current_user.id}")
|
||||||
return updated_user
|
return updated_user
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -346,17 +335,12 @@ async def delete_user(
|
|||||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user (raises NotFoundError if not found)
|
||||||
user = await user_crud.get(db, id=str(user_id))
|
await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User with id {user_id} not found",
|
|
||||||
error_code=ErrorCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use soft delete instead of hard delete
|
# Use soft delete instead of hard delete
|
||||||
await user_crud.soft_delete(db, id=str(user_id))
|
await user_service.soft_delete_user(db, str(user_id))
|
||||||
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"User {user_id} deleted successfully"
|
success=True, message=f"User {user_id} deleted successfully"
|
||||||
|
|||||||
26
backend/app/core/repository_exceptions.py
Normal file
26
backend/app/core/repository_exceptions.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
Custom exceptions for the repository layer.
|
||||||
|
|
||||||
|
These exceptions allow services and routes to handle database-level errors
|
||||||
|
with proper semantics, without leaking SQLAlchemy internals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class RepositoryError(Exception):
|
||||||
|
"""Base for all repository-layer errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateEntryError(RepositoryError):
|
||||||
|
"""Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrityConstraintError(RepositoryError):
|
||||||
|
"""Raised on FK or check constraint violations."""
|
||||||
|
|
||||||
|
|
||||||
|
class RecordNotFoundError(RepositoryError):
|
||||||
|
"""Raised when an expected record doesn't exist."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidInputError(RepositoryError):
|
||||||
|
"""Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
# app/crud/__init__.py
|
|
||||||
from .oauth import oauth_account, oauth_client, oauth_state
|
|
||||||
from .organization import organization
|
|
||||||
from .session import session as session_crud
|
|
||||||
from .user import user
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"oauth_account",
|
|
||||||
"oauth_client",
|
|
||||||
"oauth_state",
|
|
||||||
"organization",
|
|
||||||
"session_crud",
|
|
||||||
"user",
|
|
||||||
]
|
|
||||||
@@ -1,718 +0,0 @@
|
|||||||
"""
|
|
||||||
Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns.
|
|
||||||
|
|
||||||
Provides operations for:
|
|
||||||
- OAuthAccount: Managing linked OAuth provider accounts
|
|
||||||
- OAuthState: CSRF protection state during OAuth flows
|
|
||||||
- OAuthClient: Registered OAuth clients (provider mode skeleton)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import secrets
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import and_, delete, select
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
|
|
||||||
from app.crud.base import CRUDBase
|
|
||||||
from app.models.oauth_account import OAuthAccount
|
|
||||||
from app.models.oauth_client import OAuthClient
|
|
||||||
from app.models.oauth_state import OAuthState
|
|
||||||
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# OAuth Account CRUD
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class EmptySchema(BaseModel):
|
|
||||||
"""Placeholder schema for CRUD operations that don't need update schemas."""
|
|
||||||
|
|
||||||
|
|
||||||
class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]):
|
|
||||||
"""CRUD operations for OAuth account links."""
|
|
||||||
|
|
||||||
async def get_by_provider_id(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
provider: str,
|
|
||||||
provider_user_id: str,
|
|
||||||
) -> OAuthAccount | None:
|
|
||||||
"""
|
|
||||||
Get OAuth account by provider and provider user ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
provider: OAuth provider name (google, github)
|
|
||||||
provider_user_id: User ID from the OAuth provider
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OAuthAccount if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(OAuthAccount)
|
|
||||||
.where(
|
|
||||||
and_(
|
|
||||||
OAuthAccount.provider == provider,
|
|
||||||
OAuthAccount.provider_user_id == provider_user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.options(joinedload(OAuthAccount.user))
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
except Exception as e: # pragma: no cover # pragma: no cover
|
|
||||||
logger.error(
|
|
||||||
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_provider_email(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
provider: str,
|
|
||||||
email: str,
|
|
||||||
) -> OAuthAccount | None:
|
|
||||||
"""
|
|
||||||
Get OAuth account by provider and email.
|
|
||||||
|
|
||||||
Used for auto-linking existing accounts by email.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
provider: OAuth provider name
|
|
||||||
email: Email address from the OAuth provider
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OAuthAccount if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(OAuthAccount)
|
|
||||||
.where(
|
|
||||||
and_(
|
|
||||||
OAuthAccount.provider == provider,
|
|
||||||
OAuthAccount.provider_email == email,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.options(joinedload(OAuthAccount.user))
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
except Exception as e: # pragma: no cover # pragma: no cover
|
|
||||||
logger.error(
|
|
||||||
f"Error getting OAuth account for {provider} email {email}: {e!s}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_user_accounts(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: str | UUID,
|
|
||||||
) -> list[OAuthAccount]:
|
|
||||||
"""
|
|
||||||
Get all OAuth accounts linked to a user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of OAuthAccount objects
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(OAuthAccount)
|
|
||||||
.where(OAuthAccount.user_id == user_uuid)
|
|
||||||
.order_by(OAuthAccount.created_at.desc())
|
|
||||||
)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_user_account_by_provider(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: str | UUID,
|
|
||||||
provider: str,
|
|
||||||
) -> OAuthAccount | None:
|
|
||||||
"""
|
|
||||||
Get a specific OAuth account for a user and provider.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
provider: OAuth provider name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OAuthAccount if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(OAuthAccount).where(
|
|
||||||
and_(
|
|
||||||
OAuthAccount.user_id == user_uuid,
|
|
||||||
OAuthAccount.provider == provider,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
logger.error(
|
|
||||||
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def create_account(
|
|
||||||
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
|
|
||||||
) -> OAuthAccount:
|
|
||||||
"""
|
|
||||||
Create a new OAuth account link.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
obj_in: OAuth account creation data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created OAuthAccount
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If account already exists or creation fails
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
db_obj = OAuthAccount(
|
|
||||||
user_id=obj_in.user_id,
|
|
||||||
provider=obj_in.provider,
|
|
||||||
provider_user_id=obj_in.provider_user_id,
|
|
||||||
provider_email=obj_in.provider_email,
|
|
||||||
access_token_encrypted=obj_in.access_token_encrypted,
|
|
||||||
refresh_token_encrypted=obj_in.refresh_token_encrypted,
|
|
||||||
token_expires_at=obj_in.token_expires_at,
|
|
||||||
)
|
|
||||||
db.add(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_obj)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
|
|
||||||
)
|
|
||||||
return db_obj
|
|
||||||
except IntegrityError as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
|
||||||
if "uq_oauth_provider_user" in error_msg.lower():
|
|
||||||
logger.warning(
|
|
||||||
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f"This {obj_in.provider} account is already linked to another user"
|
|
||||||
)
|
|
||||||
logger.error(f"Integrity error creating OAuth account: {error_msg}")
|
|
||||||
raise ValueError(f"Failed to create OAuth account: {error_msg}")
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def delete_account(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
user_id: str | UUID,
|
|
||||||
provider: str,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Delete an OAuth account link.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
provider: OAuth provider name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deleted, False if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
delete(OAuthAccount).where(
|
|
||||||
and_(
|
|
||||||
OAuthAccount.user_id == user_uuid,
|
|
||||||
OAuthAccount.provider == provider,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
deleted = result.rowcount > 0
|
|
||||||
if deleted:
|
|
||||||
logger.info(
|
|
||||||
f"OAuth account deleted: {provider} unlinked from user {user_id}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"OAuth account not found for deletion: {provider} for user {user_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return deleted
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(
|
|
||||||
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update_tokens(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
account: OAuthAccount,
|
|
||||||
access_token_encrypted: str | None = None,
|
|
||||||
refresh_token_encrypted: str | None = None,
|
|
||||||
token_expires_at: datetime | None = None,
|
|
||||||
) -> OAuthAccount:
|
|
||||||
"""
|
|
||||||
Update OAuth tokens for an account.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
account: OAuthAccount to update
|
|
||||||
access_token_encrypted: New encrypted access token
|
|
||||||
refresh_token_encrypted: New encrypted refresh token
|
|
||||||
token_expires_at: New token expiration time
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated OAuthAccount
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if access_token_encrypted is not None:
|
|
||||||
account.access_token_encrypted = access_token_encrypted
|
|
||||||
if refresh_token_encrypted is not None:
|
|
||||||
account.refresh_token_encrypted = refresh_token_encrypted
|
|
||||||
if token_expires_at is not None:
|
|
||||||
account.token_expires_at = token_expires_at
|
|
||||||
|
|
||||||
db.add(account)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(account)
|
|
||||||
|
|
||||||
return account
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error updating OAuth tokens: {e!s}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# OAuth State CRUD
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]):
|
|
||||||
"""CRUD operations for OAuth state (CSRF protection)."""
|
|
||||||
|
|
||||||
async def create_state(
|
|
||||||
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
|
||||||
) -> OAuthState:
|
|
||||||
"""
|
|
||||||
Create a new OAuth state for CSRF protection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
obj_in: OAuth state creation data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created OAuthState
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
db_obj = OAuthState(
|
|
||||||
state=obj_in.state,
|
|
||||||
code_verifier=obj_in.code_verifier,
|
|
||||||
nonce=obj_in.nonce,
|
|
||||||
provider=obj_in.provider,
|
|
||||||
redirect_uri=obj_in.redirect_uri,
|
|
||||||
user_id=obj_in.user_id,
|
|
||||||
expires_at=obj_in.expires_at,
|
|
||||||
)
|
|
||||||
db.add(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_obj)
|
|
||||||
|
|
||||||
logger.debug(f"OAuth state created for {obj_in.provider}")
|
|
||||||
return db_obj
|
|
||||||
except IntegrityError as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
# State collision (extremely rare with cryptographic random)
|
|
||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
|
||||||
logger.error(f"OAuth state collision: {error_msg}")
|
|
||||||
raise ValueError("Failed to create OAuth state, please retry")
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_and_consume_state(
|
|
||||||
self, db: AsyncSession, *, state: str
|
|
||||||
) -> OAuthState | None:
|
|
||||||
"""
|
|
||||||
Get and delete OAuth state (consume it).
|
|
||||||
|
|
||||||
This ensures each state can only be used once (replay protection).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
state: State string to look up
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OAuthState if found and valid, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get the state
|
|
||||||
result = await db.execute(
|
|
||||||
select(OAuthState).where(OAuthState.state == state)
|
|
||||||
)
|
|
||||||
db_obj = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if db_obj is None:
|
|
||||||
logger.warning(f"OAuth state not found: {state[:8]}...")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Check expiration
|
|
||||||
# Handle both timezone-aware and timezone-naive datetimes
|
|
||||||
now = datetime.now(UTC)
|
|
||||||
expires_at = db_obj.expires_at
|
|
||||||
if expires_at.tzinfo is None:
|
|
||||||
# SQLite returns naive datetimes, assume UTC
|
|
||||||
expires_at = expires_at.replace(tzinfo=UTC)
|
|
||||||
|
|
||||||
if expires_at < now:
|
|
||||||
logger.warning(f"OAuth state expired: {state[:8]}...")
|
|
||||||
await db.delete(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Delete it (consume)
|
|
||||||
await db.delete(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
logger.debug(f"OAuth state consumed: {state[:8]}...")
|
|
||||||
return db_obj
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error consuming OAuth state: {e!s}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
|
||||||
"""
|
|
||||||
Clean up expired OAuth states.
|
|
||||||
|
|
||||||
Should be called periodically to remove stale states.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of states deleted
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
now = datetime.now(UTC)
|
|
||||||
|
|
||||||
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
|
||||||
result = await db.execute(stmt)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
count = result.rowcount
|
|
||||||
if count > 0:
|
|
||||||
logger.info(f"Cleaned up {count} expired OAuth states")
|
|
||||||
|
|
||||||
return count
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# OAuth Client CRUD (Provider Mode - Skeleton)
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
|
|
||||||
"""
|
|
||||||
CRUD operations for OAuth clients (provider mode).
|
|
||||||
|
|
||||||
This is a skeleton implementation for MCP client registration.
|
|
||||||
Full implementation can be expanded when needed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def get_by_client_id(
|
|
||||||
self, db: AsyncSession, *, client_id: str
|
|
||||||
) -> OAuthClient | None:
|
|
||||||
"""
|
|
||||||
Get OAuth client by client_id.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
client_id: OAuth client ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OAuthClient if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(OAuthClient).where(
|
|
||||||
and_(
|
|
||||||
OAuthClient.client_id == client_id,
|
|
||||||
OAuthClient.is_active == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def create_client(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
*,
|
|
||||||
obj_in: OAuthClientCreate,
|
|
||||||
owner_user_id: UUID | None = None,
|
|
||||||
) -> tuple[OAuthClient, str | None]:
|
|
||||||
"""
|
|
||||||
Create a new OAuth client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
obj_in: OAuth client creation data
|
|
||||||
owner_user_id: Optional owner user ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (created OAuthClient, client_secret or None for public clients)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Generate client_id
|
|
||||||
client_id = secrets.token_urlsafe(32)
|
|
||||||
|
|
||||||
# Generate client_secret for confidential clients
|
|
||||||
client_secret = None
|
|
||||||
client_secret_hash = None
|
|
||||||
if obj_in.client_type == "confidential":
|
|
||||||
client_secret = secrets.token_urlsafe(48)
|
|
||||||
# SECURITY: Use bcrypt for secret storage (not SHA-256)
|
|
||||||
# bcrypt is computationally expensive, making brute-force attacks infeasible
|
|
||||||
from app.core.auth import get_password_hash
|
|
||||||
|
|
||||||
client_secret_hash = get_password_hash(client_secret)
|
|
||||||
|
|
||||||
db_obj = OAuthClient(
|
|
||||||
client_id=client_id,
|
|
||||||
client_secret_hash=client_secret_hash,
|
|
||||||
client_name=obj_in.client_name,
|
|
||||||
client_description=obj_in.client_description,
|
|
||||||
client_type=obj_in.client_type,
|
|
||||||
redirect_uris=obj_in.redirect_uris,
|
|
||||||
allowed_scopes=obj_in.allowed_scopes,
|
|
||||||
owner_user_id=owner_user_id,
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
db.add(db_obj)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_obj)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
|
|
||||||
)
|
|
||||||
return db_obj, client_secret
|
|
||||||
except IntegrityError as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
|
||||||
logger.error(f"Error creating OAuth client: {error_msg}")
|
|
||||||
raise ValueError(f"Failed to create OAuth client: {error_msg}")
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def deactivate_client(
|
|
||||||
self, db: AsyncSession, *, client_id: str
|
|
||||||
) -> OAuthClient | None:
|
|
||||||
"""
|
|
||||||
Deactivate an OAuth client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
client_id: OAuth client ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deactivated OAuthClient if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
client = await self.get_by_client_id(db, client_id=client_id)
|
|
||||||
if client is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
client.is_active = False
|
|
||||||
db.add(client)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(client)
|
|
||||||
|
|
||||||
logger.info(f"OAuth client deactivated: {client.client_name}")
|
|
||||||
return client
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def validate_redirect_uri(
|
|
||||||
self, db: AsyncSession, *, client_id: str, redirect_uri: str
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Validate that a redirect URI is allowed for a client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
client_id: OAuth client ID
|
|
||||||
redirect_uri: Redirect URI to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
client = await self.get_by_client_id(db, client_id=client_id)
|
|
||||||
if client is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return redirect_uri in (client.redirect_uris or [])
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
logger.error(f"Error validating redirect URI: {e!s}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def verify_client_secret(
|
|
||||||
self, db: AsyncSession, *, client_id: str, client_secret: str
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Verify client credentials.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
client_id: OAuth client ID
|
|
||||||
client_secret: Client secret to verify
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
select(OAuthClient).where(
|
|
||||||
and_(
|
|
||||||
OAuthClient.client_id == client_id,
|
|
||||||
OAuthClient.is_active == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
client = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if client is None or client.client_secret_hash is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
|
||||||
# This supports both old SHA-256 hashes (for migration) and new bcrypt hashes
|
|
||||||
from app.core.auth import verify_password
|
|
||||||
|
|
||||||
stored_hash: str = str(client.client_secret_hash)
|
|
||||||
|
|
||||||
# Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256
|
|
||||||
if stored_hash.startswith("$2"):
|
|
||||||
# New bcrypt format
|
|
||||||
return verify_password(client_secret, stored_hash)
|
|
||||||
else:
|
|
||||||
# Legacy SHA-256 format - still support for migration
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
|
||||||
return secrets.compare_digest(stored_hash, secret_hash)
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
logger.error(f"Error verifying client secret: {e!s}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_all_clients(
|
|
||||||
self, db: AsyncSession, *, include_inactive: bool = False
|
|
||||||
) -> list[OAuthClient]:
|
|
||||||
"""
|
|
||||||
Get all OAuth clients.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
include_inactive: Whether to include inactive clients
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of OAuthClient objects
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
|
||||||
if not include_inactive:
|
|
||||||
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
logger.error(f"Error getting all OAuth clients: {e!s}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Delete an OAuth client permanently.
|
|
||||||
|
|
||||||
Note: This will cascade delete related records (tokens, consents, etc.)
|
|
||||||
due to foreign key constraints.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
client_id: OAuth client ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deleted, False if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await db.execute(
|
|
||||||
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
deleted = result.rowcount > 0
|
|
||||||
if deleted:
|
|
||||||
logger.info(f"OAuth client deleted: {client_id}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"OAuth client not found for deletion: {client_id}")
|
|
||||||
|
|
||||||
return deleted
|
|
||||||
except Exception as e: # pragma: no cover
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Singleton instances
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
oauth_account = CRUDOAuthAccount(OAuthAccount)
|
|
||||||
oauth_state = CRUDOAuthState(OAuthState)
|
|
||||||
oauth_client = CRUDOAuthClient(OAuthClient)
|
|
||||||
@@ -16,7 +16,7 @@ from sqlalchemy import select, text
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import SessionLocal, engine
|
from app.core.database import SessionLocal, engine
|
||||||
from app.crud.user import user as user_crud
|
from app.repositories.user import user_repo as user_crud
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import UserOrganization
|
from app.models.user_organization import UserOrganization
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ class OAuthAccount(Base, UUIDMixin, TimestampMixin):
|
|||||||
) # Email from provider (for reference)
|
) # Email from provider (for reference)
|
||||||
|
|
||||||
# Optional: store provider tokens for API access
|
# Optional: store provider tokens for API access
|
||||||
# These should be encrypted at rest in production
|
# TODO: Encrypt these at rest in production (requires key management infrastructure)
|
||||||
access_token_encrypted = Column(String(2048), nullable=True)
|
access_token = Column(String(2048), nullable=True)
|
||||||
refresh_token_encrypted = Column(String(2048), nullable=True)
|
refresh_token = Column(String(2048), nullable=True)
|
||||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
# Relationship
|
# Relationship
|
||||||
|
|||||||
39
backend/app/repositories/__init__.py
Normal file
39
backend/app/repositories/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# app/repositories/__init__.py
|
||||||
|
"""Repository layer — all database access goes through these classes."""
|
||||||
|
|
||||||
|
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
|
||||||
|
from app.repositories.oauth_authorization_code import (
|
||||||
|
OAuthAuthorizationCodeRepository,
|
||||||
|
oauth_authorization_code_repo,
|
||||||
|
)
|
||||||
|
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
|
||||||
|
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
|
||||||
|
from app.repositories.oauth_provider_token import (
|
||||||
|
OAuthProviderTokenRepository,
|
||||||
|
oauth_provider_token_repo,
|
||||||
|
)
|
||||||
|
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
|
||||||
|
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||||
|
from app.repositories.session import SessionRepository, session_repo
|
||||||
|
from app.repositories.user import UserRepository, user_repo
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"UserRepository",
|
||||||
|
"user_repo",
|
||||||
|
"OrganizationRepository",
|
||||||
|
"organization_repo",
|
||||||
|
"SessionRepository",
|
||||||
|
"session_repo",
|
||||||
|
"OAuthAccountRepository",
|
||||||
|
"oauth_account_repo",
|
||||||
|
"OAuthAuthorizationCodeRepository",
|
||||||
|
"oauth_authorization_code_repo",
|
||||||
|
"OAuthClientRepository",
|
||||||
|
"oauth_client_repo",
|
||||||
|
"OAuthConsentRepository",
|
||||||
|
"oauth_consent_repo",
|
||||||
|
"OAuthProviderTokenRepository",
|
||||||
|
"oauth_provider_token_repo",
|
||||||
|
"OAuthStateRepository",
|
||||||
|
"oauth_state_repo",
|
||||||
|
]
|
||||||
101
backend/app/crud/base.py → backend/app/repositories/base.py
Executable file → Normal file
101
backend/app/crud/base.py → backend/app/repositories/base.py
Executable file → Normal file
@@ -1,6 +1,6 @@
|
|||||||
# app/crud/base_async.py
|
# app/repositories/base.py
|
||||||
"""
|
"""
|
||||||
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
Base repository class for async CRUD operations using SQLAlchemy 2.0 async patterns.
|
||||||
|
|
||||||
Provides reusable create, read, update, and delete operations for all models.
|
Provides reusable create, read, update, and delete operations for all models.
|
||||||
"""
|
"""
|
||||||
@@ -18,6 +18,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import Load
|
from sqlalchemy.orm import Load
|
||||||
|
|
||||||
from app.core.database import Base
|
from app.core.database import Base
|
||||||
|
from app.core.repository_exceptions import (
|
||||||
|
DuplicateEntryError,
|
||||||
|
IntegrityConstraintError,
|
||||||
|
InvalidInputError,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -26,16 +31,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
|||||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class CRUDBase[
|
class BaseRepository[
|
||||||
ModelType: Base,
|
ModelType: Base,
|
||||||
CreateSchemaType: BaseModel,
|
CreateSchemaType: BaseModel,
|
||||||
UpdateSchemaType: BaseModel,
|
UpdateSchemaType: BaseModel,
|
||||||
]:
|
]:
|
||||||
"""Async CRUD operations for a model."""
|
"""Async repository operations for a model."""
|
||||||
|
|
||||||
def __init__(self, model: type[ModelType]):
|
def __init__(self, model: type[ModelType]):
|
||||||
"""
|
"""
|
||||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
Repository object with default async methods to Create, Read, Update, Delete.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
model: A SQLAlchemy model class
|
model: A SQLAlchemy model class
|
||||||
@@ -56,13 +61,7 @@ class CRUDBase[
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model instance or None if not found
|
Model instance or None if not found
|
||||||
|
|
||||||
Example:
|
|
||||||
# Eager load user relationship
|
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
|
||||||
"""
|
"""
|
||||||
# Validate UUID format and convert to UUID object if string
|
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
uuid_obj = id
|
uuid_obj = id
|
||||||
@@ -75,7 +74,6 @@ class CRUDBase[
|
|||||||
try:
|
try:
|
||||||
query = select(self.model).where(self.model.id == uuid_obj)
|
query = select(self.model).where(self.model.id == uuid_obj)
|
||||||
|
|
||||||
# Apply eager loading options if provided
|
|
||||||
if options:
|
if options:
|
||||||
for option in options:
|
for option in options:
|
||||||
query = query.options(option)
|
query = query.options(option)
|
||||||
@@ -96,28 +94,17 @@ class CRUDBase[
|
|||||||
) -> list[ModelType]:
|
) -> list[ModelType]:
|
||||||
"""
|
"""
|
||||||
Get multiple records with pagination validation and optional eager loading.
|
Get multiple records with pagination validation and optional eager loading.
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
options: Optional list of SQLAlchemy load options for eager loading
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model instances
|
|
||||||
"""
|
"""
|
||||||
# Validate pagination parameters
|
|
||||||
if skip < 0:
|
if skip < 0:
|
||||||
raise ValueError("skip must be non-negative")
|
raise InvalidInputError("skip must be non-negative")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
raise ValueError("limit must be non-negative")
|
raise InvalidInputError("limit must be non-negative")
|
||||||
if limit > 1000:
|
if limit > 1000:
|
||||||
raise ValueError("Maximum limit is 1000")
|
raise InvalidInputError("Maximum limit is 1000")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = select(self.model).offset(skip).limit(limit)
|
query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
|
||||||
|
|
||||||
# Apply eager loading options if provided
|
|
||||||
if options:
|
if options:
|
||||||
for option in options:
|
for option in options:
|
||||||
query = query.options(option)
|
query = query.options(option)
|
||||||
@@ -136,9 +123,8 @@ class CRUDBase[
|
|||||||
"""Create a new record with error handling.
|
"""Create a new record with error handling.
|
||||||
|
|
||||||
NOTE: This method is defensive code that's never called in practice.
|
NOTE: This method is defensive code that's never called in practice.
|
||||||
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
|
All repository subclasses override this method with their own implementations.
|
||||||
with their own implementations, so the base implementation and its exception handlers
|
Marked as pragma: no cover to avoid false coverage gaps.
|
||||||
are never executed. Marked as pragma: no cover to avoid false coverage gaps.
|
|
||||||
"""
|
"""
|
||||||
try: # pragma: no cover
|
try: # pragma: no cover
|
||||||
obj_in_data = jsonable_encoder(obj_in)
|
obj_in_data = jsonable_encoder(obj_in)
|
||||||
@@ -154,15 +140,15 @@ class CRUDBase[
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise DuplicateEntryError(
|
||||||
f"A {self.model.__name__} with this data already exists"
|
f"A {self.model.__name__} with this data already exists"
|
||||||
)
|
)
|
||||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||||
except (OperationalError, DataError) as e: # pragma: no cover
|
except (OperationalError, DataError) as e: # pragma: no cover
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
|
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
|
||||||
raise ValueError(f"Database operation failed: {e!s}")
|
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e: # pragma: no cover
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -200,15 +186,15 @@ class CRUDBase[
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise DuplicateEntryError(
|
||||||
f"A {self.model.__name__} with this data already exists"
|
f"A {self.model.__name__} with this data already exists"
|
||||||
)
|
)
|
||||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||||
except (OperationalError, DataError) as e:
|
except (OperationalError, DataError) as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
|
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
|
||||||
raise ValueError(f"Database operation failed: {e!s}")
|
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -218,7 +204,6 @@ class CRUDBase[
|
|||||||
|
|
||||||
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||||
"""Delete a record with error handling and null check."""
|
"""Delete a record with error handling and null check."""
|
||||||
# Validate UUID format and convert to UUID object if string
|
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
uuid_obj = id
|
uuid_obj = id
|
||||||
@@ -247,7 +232,7 @@ class CRUDBase[
|
|||||||
await db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||||
raise ValueError(
|
raise IntegrityConstraintError(
|
||||||
f"Cannot delete {self.model.__name__}: referenced by other records"
|
f"Cannot delete {self.model.__name__}: referenced by other records"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -272,57 +257,40 @@ class CRUDBase[
|
|||||||
Get multiple records with total count, filtering, and sorting.
|
Get multiple records with total count, filtering, and sorting.
|
||||||
|
|
||||||
NOTE: This method is defensive code that's never called in practice.
|
NOTE: This method is defensive code that's never called in practice.
|
||||||
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
|
All repository subclasses override this method with their own implementations.
|
||||||
with their own implementations that include additional parameters like search.
|
|
||||||
Marked as pragma: no cover to avoid false coverage gaps.
|
Marked as pragma: no cover to avoid false coverage gaps.
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
sort_by: Field name to sort by (must be a valid model attribute)
|
|
||||||
sort_order: Sort order ("asc" or "desc")
|
|
||||||
filters: Dictionary of filters (field_name: value)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (items, total_count)
|
|
||||||
"""
|
"""
|
||||||
# Validate pagination parameters
|
|
||||||
if skip < 0:
|
if skip < 0:
|
||||||
raise ValueError("skip must be non-negative")
|
raise InvalidInputError("skip must be non-negative")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
raise ValueError("limit must be non-negative")
|
raise InvalidInputError("limit must be non-negative")
|
||||||
if limit > 1000:
|
if limit > 1000:
|
||||||
raise ValueError("Maximum limit is 1000")
|
raise InvalidInputError("Maximum limit is 1000")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build base query
|
|
||||||
query = select(self.model)
|
query = select(self.model)
|
||||||
|
|
||||||
# Exclude soft-deleted records by default
|
|
||||||
if hasattr(self.model, "deleted_at"):
|
if hasattr(self.model, "deleted_at"):
|
||||||
query = query.where(self.model.deleted_at.is_(None))
|
query = query.where(self.model.deleted_at.is_(None))
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if filters:
|
if filters:
|
||||||
for field, value in filters.items():
|
for field, value in filters.items():
|
||||||
if hasattr(self.model, field) and value is not None:
|
if hasattr(self.model, field) and value is not None:
|
||||||
query = query.where(getattr(self.model, field) == value)
|
query = query.where(getattr(self.model, field) == value)
|
||||||
|
|
||||||
# Get total count (before pagination)
|
|
||||||
count_query = select(func.count()).select_from(query.alias())
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply sorting
|
|
||||||
if sort_by and hasattr(self.model, sort_by):
|
if sort_by and hasattr(self.model, sort_by):
|
||||||
sort_column = getattr(self.model, sort_by)
|
sort_column = getattr(self.model, sort_by)
|
||||||
if sort_order.lower() == "desc":
|
if sort_order.lower() == "desc":
|
||||||
query = query.order_by(sort_column.desc())
|
query = query.order_by(sort_column.desc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(sort_column.asc())
|
query = query.order_by(sort_column.asc())
|
||||||
|
else:
|
||||||
|
query = query.order_by(self.model.id)
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
query = query.offset(skip).limit(limit)
|
query = query.offset(skip).limit(limit)
|
||||||
items_result = await db.execute(query)
|
items_result = await db.execute(query)
|
||||||
items = list(items_result.scalars().all())
|
items = list(items_result.scalars().all())
|
||||||
@@ -356,7 +324,6 @@ class CRUDBase[
|
|||||||
"""
|
"""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
# Validate UUID format and convert to UUID object if string
|
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
uuid_obj = id
|
uuid_obj = id
|
||||||
@@ -378,14 +345,12 @@ class CRUDBase[
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if model supports soft deletes
|
|
||||||
if not hasattr(self.model, "deleted_at"):
|
if not hasattr(self.model, "deleted_at"):
|
||||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||||
raise ValueError(
|
raise InvalidInputError(
|
||||||
f"{self.model.__name__} does not have a deleted_at column"
|
f"{self.model.__name__} does not have a deleted_at column"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set deleted_at timestamp
|
|
||||||
obj.deleted_at = datetime.now(UTC)
|
obj.deleted_at = datetime.now(UTC)
|
||||||
db.add(obj)
|
db.add(obj)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -405,7 +370,6 @@ class CRUDBase[
|
|||||||
|
|
||||||
Only works if the model has a 'deleted_at' column.
|
Only works if the model has a 'deleted_at' column.
|
||||||
"""
|
"""
|
||||||
# Validate UUID format
|
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
uuid_obj = id
|
uuid_obj = id
|
||||||
@@ -416,7 +380,6 @@ class CRUDBase[
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Find the soft-deleted record
|
|
||||||
if hasattr(self.model, "deleted_at"):
|
if hasattr(self.model, "deleted_at"):
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(self.model).where(
|
select(self.model).where(
|
||||||
@@ -426,7 +389,7 @@ class CRUDBase[
|
|||||||
obj = result.scalar_one_or_none()
|
obj = result.scalar_one_or_none()
|
||||||
else:
|
else:
|
||||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||||
raise ValueError(
|
raise InvalidInputError(
|
||||||
f"{self.model.__name__} does not have a deleted_at column"
|
f"{self.model.__name__} does not have a deleted_at column"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -436,7 +399,6 @@ class CRUDBase[
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Clear deleted_at timestamp
|
|
||||||
obj.deleted_at = None
|
obj.deleted_at = None
|
||||||
db.add(obj)
|
db.add(obj)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -449,3 +411,4 @@ class CRUDBase[
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
235
backend/app/repositories/oauth_account.py
Normal file
235
backend/app/repositories/oauth_account.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
# app/repositories/oauth_account.py
|
||||||
|
"""Repository for OAuthAccount model async CRUD operations."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import and_, delete, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
|
from app.models.oauth_account import OAuthAccount
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
from app.schemas.oauth import OAuthAccountCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptySchema(BaseModel):
|
||||||
|
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountRepository(BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]):
|
||||||
|
"""Repository for OAuth account links."""
|
||||||
|
|
||||||
|
async def get_by_provider_id(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
provider: str,
|
||||||
|
provider_user_id: str,
|
||||||
|
) -> OAuthAccount | None:
|
||||||
|
"""Get OAuth account by provider and provider user ID."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
OAuthAccount.provider_user_id == provider_user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.options(joinedload(OAuthAccount.user))
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(
|
||||||
|
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_provider_email(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
provider: str,
|
||||||
|
email: str,
|
||||||
|
) -> OAuthAccount | None:
|
||||||
|
"""Get OAuth account by provider and email."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
OAuthAccount.provider_email == email,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.options(joinedload(OAuthAccount.user))
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(
|
||||||
|
f"Error getting OAuth account for {provider} email {email}: {e!s}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_user_accounts(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: str | UUID,
|
||||||
|
) -> list[OAuthAccount]:
|
||||||
|
"""Get all OAuth accounts linked to a user."""
|
||||||
|
try:
|
||||||
|
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount)
|
||||||
|
.where(OAuthAccount.user_id == user_uuid)
|
||||||
|
.order_by(OAuthAccount.created_at.desc())
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_user_account_by_provider(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: str | UUID,
|
||||||
|
provider: str,
|
||||||
|
) -> OAuthAccount | None:
|
||||||
|
"""Get a specific OAuth account for a user and provider."""
|
||||||
|
try:
|
||||||
|
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount).where(
|
||||||
|
and_(
|
||||||
|
OAuthAccount.user_id == user_uuid,
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(
|
||||||
|
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create_account(
|
||||||
|
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
|
||||||
|
) -> OAuthAccount:
|
||||||
|
"""Create a new OAuth account link."""
|
||||||
|
try:
|
||||||
|
db_obj = OAuthAccount(
|
||||||
|
user_id=obj_in.user_id,
|
||||||
|
provider=obj_in.provider,
|
||||||
|
provider_user_id=obj_in.provider_user_id,
|
||||||
|
provider_email=obj_in.provider_email,
|
||||||
|
access_token=obj_in.access_token,
|
||||||
|
refresh_token=obj_in.refresh_token,
|
||||||
|
token_expires_at=obj_in.token_expires_at,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
|
||||||
|
)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
if "uq_oauth_provider_user" in error_msg.lower():
|
||||||
|
logger.warning(
|
||||||
|
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
|
||||||
|
)
|
||||||
|
raise DuplicateEntryError(
|
||||||
|
f"This {obj_in.provider} account is already linked to another user"
|
||||||
|
)
|
||||||
|
logger.error(f"Integrity error creating OAuth account: {error_msg}")
|
||||||
|
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete_account(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: str | UUID,
|
||||||
|
provider: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Delete an OAuth account link."""
|
||||||
|
try:
|
||||||
|
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthAccount).where(
|
||||||
|
and_(
|
||||||
|
OAuthAccount.user_id == user_uuid,
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
deleted = result.rowcount > 0
|
||||||
|
if deleted:
|
||||||
|
logger.info(
|
||||||
|
f"OAuth account deleted: {provider} unlinked from user {user_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"OAuth account not found for deletion: {provider} for user {user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_tokens(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
account: OAuthAccount,
|
||||||
|
access_token: str | None = None,
|
||||||
|
refresh_token: str | None = None,
|
||||||
|
token_expires_at: datetime | None = None,
|
||||||
|
) -> OAuthAccount:
|
||||||
|
"""Update OAuth tokens for an account."""
|
||||||
|
try:
|
||||||
|
if access_token is not None:
|
||||||
|
account.access_token = access_token
|
||||||
|
if refresh_token is not None:
|
||||||
|
account.refresh_token = refresh_token
|
||||||
|
if token_expires_at is not None:
|
||||||
|
account.token_expires_at = token_expires_at
|
||||||
|
|
||||||
|
db.add(account)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(account)
|
||||||
|
|
||||||
|
return account
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error updating OAuth tokens: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_account_repo = OAuthAccountRepository(OAuthAccount)
|
||||||
108
backend/app/repositories/oauth_authorization_code.py
Normal file
108
backend/app/repositories/oauth_authorization_code.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
# app/repositories/oauth_authorization_code.py
|
||||||
|
"""Repository for OAuthAuthorizationCode model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import and_, delete, select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAuthorizationCodeRepository:
|
||||||
|
"""Repository for OAuth 2.0 authorization codes."""
|
||||||
|
|
||||||
|
async def create_code(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
code: str,
|
||||||
|
client_id: str,
|
||||||
|
user_id: UUID,
|
||||||
|
redirect_uri: str,
|
||||||
|
scope: str,
|
||||||
|
expires_at: datetime,
|
||||||
|
code_challenge: str | None = None,
|
||||||
|
code_challenge_method: str | None = None,
|
||||||
|
state: str | None = None,
|
||||||
|
nonce: str | None = None,
|
||||||
|
) -> OAuthAuthorizationCode:
|
||||||
|
"""Create and persist a new authorization code."""
|
||||||
|
auth_code = OAuthAuthorizationCode(
|
||||||
|
code=code,
|
||||||
|
client_id=client_id,
|
||||||
|
user_id=user_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=scope,
|
||||||
|
code_challenge=code_challenge,
|
||||||
|
code_challenge_method=code_challenge_method,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
expires_at=expires_at,
|
||||||
|
used=False,
|
||||||
|
)
|
||||||
|
db.add(auth_code)
|
||||||
|
await db.commit()
|
||||||
|
return auth_code
|
||||||
|
|
||||||
|
async def consume_code_atomically(
|
||||||
|
self, db: AsyncSession, *, code: str
|
||||||
|
) -> UUID | None:
|
||||||
|
"""
|
||||||
|
Atomically mark a code as used and return its UUID.
|
||||||
|
|
||||||
|
Returns the UUID if the code was found and not yet used, None otherwise.
|
||||||
|
This prevents race conditions per RFC 6749 Section 4.1.2.
|
||||||
|
"""
|
||||||
|
stmt = (
|
||||||
|
update(OAuthAuthorizationCode)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthAuthorizationCode.code == code,
|
||||||
|
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(used=True)
|
||||||
|
.returning(OAuthAuthorizationCode.id)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
row_id = result.scalar_one_or_none()
|
||||||
|
if row_id is not None:
|
||||||
|
await db.commit()
|
||||||
|
return row_id
|
||||||
|
|
||||||
|
async def get_by_id(
|
||||||
|
self, db: AsyncSession, *, code_id: UUID
|
||||||
|
) -> OAuthAuthorizationCode | None:
|
||||||
|
"""Get authorization code by its UUID primary key."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_code(
|
||||||
|
self, db: AsyncSession, *, code: str
|
||||||
|
) -> OAuthAuthorizationCode | None:
|
||||||
|
"""Get authorization code by the code string value."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||||
|
"""Delete all expired authorization codes. Returns count deleted."""
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthAuthorizationCode).where(
|
||||||
|
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return result.rowcount # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()
|
||||||
199
backend/app/repositories/oauth_client.py
Normal file
199
backend/app/repositories/oauth_client.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
# app/repositories/oauth_client.py
|
||||||
|
"""Repository for OAuthClient model async CRUD operations."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import and_, delete, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
|
from app.models.oauth_client import OAuthClient
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
from app.schemas.oauth import OAuthClientCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptySchema(BaseModel):
|
||||||
|
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientRepository(BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]):
|
||||||
|
"""Repository for OAuth clients (provider mode)."""
|
||||||
|
|
||||||
|
async def get_by_client_id(
|
||||||
|
self, db: AsyncSession, *, client_id: str
|
||||||
|
) -> OAuthClient | None:
|
||||||
|
"""Get OAuth client by client_id."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthClient).where(
|
||||||
|
and_(
|
||||||
|
OAuthClient.client_id == client_id,
|
||||||
|
OAuthClient.is_active == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create_client(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
obj_in: OAuthClientCreate,
|
||||||
|
owner_user_id: UUID | None = None,
|
||||||
|
) -> tuple[OAuthClient, str | None]:
|
||||||
|
"""Create a new OAuth client."""
|
||||||
|
try:
|
||||||
|
client_id = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
client_secret = None
|
||||||
|
client_secret_hash = None
|
||||||
|
if obj_in.client_type == "confidential":
|
||||||
|
client_secret = secrets.token_urlsafe(48)
|
||||||
|
from app.core.auth import get_password_hash
|
||||||
|
|
||||||
|
client_secret_hash = get_password_hash(client_secret)
|
||||||
|
|
||||||
|
db_obj = OAuthClient(
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret_hash=client_secret_hash,
|
||||||
|
client_name=obj_in.client_name,
|
||||||
|
client_description=obj_in.client_description,
|
||||||
|
client_type=obj_in.client_type,
|
||||||
|
redirect_uris=obj_in.redirect_uris,
|
||||||
|
allowed_scopes=obj_in.allowed_scopes,
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
|
||||||
|
)
|
||||||
|
return db_obj, client_secret
|
||||||
|
except IntegrityError as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
logger.error(f"Error creating OAuth client: {error_msg}")
|
||||||
|
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def deactivate_client(
|
||||||
|
self, db: AsyncSession, *, client_id: str
|
||||||
|
) -> OAuthClient | None:
|
||||||
|
"""Deactivate an OAuth client."""
|
||||||
|
try:
|
||||||
|
client = await self.get_by_client_id(db, client_id=client_id)
|
||||||
|
if client is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
client.is_active = False
|
||||||
|
db.add(client)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(client)
|
||||||
|
|
||||||
|
logger.info(f"OAuth client deactivated: {client.client_name}")
|
||||||
|
return client
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def validate_redirect_uri(
|
||||||
|
self, db: AsyncSession, *, client_id: str, redirect_uri: str
|
||||||
|
) -> bool:
|
||||||
|
"""Validate that a redirect URI is allowed for a client."""
|
||||||
|
try:
|
||||||
|
client = await self.get_by_client_id(db, client_id=client_id)
|
||||||
|
if client is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return redirect_uri in (client.redirect_uris or [])
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(f"Error validating redirect URI: {e!s}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def verify_client_secret(
|
||||||
|
self, db: AsyncSession, *, client_id: str, client_secret: str
|
||||||
|
) -> bool:
|
||||||
|
"""Verify client credentials."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthClient).where(
|
||||||
|
and_(
|
||||||
|
OAuthClient.client_id == client_id,
|
||||||
|
OAuthClient.is_active == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
client = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if client is None or client.client_secret_hash is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
from app.core.auth import verify_password
|
||||||
|
|
||||||
|
stored_hash: str = str(client.client_secret_hash)
|
||||||
|
|
||||||
|
if stored_hash.startswith("$2"):
|
||||||
|
return verify_password(client_secret, stored_hash)
|
||||||
|
else:
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||||
|
return secrets.compare_digest(stored_hash, secret_hash)
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(f"Error verifying client secret: {e!s}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_all_clients(
|
||||||
|
self, db: AsyncSession, *, include_inactive: bool = False
|
||||||
|
) -> list[OAuthClient]:
|
||||||
|
"""Get all OAuth clients."""
|
||||||
|
try:
|
||||||
|
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
||||||
|
if not include_inactive:
|
||||||
|
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(f"Error getting all OAuth clients: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
||||||
|
"""Delete an OAuth client permanently."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
deleted = result.rowcount > 0
|
||||||
|
if deleted:
|
||||||
|
logger.info(f"OAuth client deleted: {client_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"OAuth client not found for deletion: {client_id}")
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_client_repo = OAuthClientRepository(OAuthClient)
|
||||||
112
backend/app/repositories/oauth_consent.py
Normal file
112
backend/app/repositories/oauth_consent.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
# app/repositories/oauth_consent.py
|
||||||
|
"""Repository for OAuthConsent model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import and_, delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.oauth_client import OAuthClient
|
||||||
|
from app.models.oauth_provider_token import OAuthConsent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthConsentRepository:
|
||||||
|
"""Repository for OAuth consent records (user grants to clients)."""
|
||||||
|
|
||||||
|
async def get_consent(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||||
|
) -> OAuthConsent | None:
|
||||||
|
"""Get the consent record for a user-client pair, or None if not found."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthConsent).where(
|
||||||
|
and_(
|
||||||
|
OAuthConsent.user_id == user_id,
|
||||||
|
OAuthConsent.client_id == client_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def grant_consent(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: UUID,
|
||||||
|
client_id: str,
|
||||||
|
scopes: list[str],
|
||||||
|
) -> OAuthConsent:
|
||||||
|
"""
|
||||||
|
Create or update consent for a user-client pair.
|
||||||
|
|
||||||
|
If consent already exists, the new scopes are merged with existing ones.
|
||||||
|
Returns the created or updated consent record.
|
||||||
|
"""
|
||||||
|
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
|
||||||
|
|
||||||
|
if consent:
|
||||||
|
existing = set(consent.granted_scopes.split()) if consent.granted_scopes else set()
|
||||||
|
merged = existing | set(scopes)
|
||||||
|
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
|
||||||
|
else:
|
||||||
|
consent = OAuthConsent(
|
||||||
|
user_id=user_id,
|
||||||
|
client_id=client_id,
|
||||||
|
granted_scopes=" ".join(sorted(set(scopes))),
|
||||||
|
)
|
||||||
|
db.add(consent)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(consent)
|
||||||
|
return consent
|
||||||
|
|
||||||
|
async def get_user_consents_with_clients(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Get all consent records for a user joined with client details."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthConsent, OAuthClient)
|
||||||
|
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
||||||
|
.where(OAuthConsent.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"client_id": consent.client_id,
|
||||||
|
"client_name": client.client_name,
|
||||||
|
"client_description": client.client_description,
|
||||||
|
"granted_scopes": consent.granted_scopes.split()
|
||||||
|
if consent.granted_scopes
|
||||||
|
else [],
|
||||||
|
"granted_at": consent.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
for consent, client in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
async def revoke_consent(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Delete the consent record for a user-client pair.
|
||||||
|
|
||||||
|
Returns True if a record was found and deleted.
|
||||||
|
Note: Callers are responsible for also revoking associated tokens.
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthConsent).where(
|
||||||
|
and_(
|
||||||
|
OAuthConsent.user_id == user_id,
|
||||||
|
OAuthConsent.client_id == client_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_consent_repo = OAuthConsentRepository()
|
||||||
146
backend/app/repositories/oauth_provider_token.py
Normal file
146
backend/app/repositories/oauth_provider_token.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
# app/repositories/oauth_provider_token.py
|
||||||
|
"""Repository for OAuthProviderRefreshToken model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import and_, delete, select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.oauth_provider_token import OAuthProviderRefreshToken
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProviderTokenRepository:
|
||||||
|
"""Repository for OAuth provider refresh tokens."""
|
||||||
|
|
||||||
|
async def create_token(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
token_hash: str,
|
||||||
|
jti: str,
|
||||||
|
client_id: str,
|
||||||
|
user_id: UUID,
|
||||||
|
scope: str,
|
||||||
|
expires_at: datetime,
|
||||||
|
device_info: str | None = None,
|
||||||
|
ip_address: str | None = None,
|
||||||
|
) -> OAuthProviderRefreshToken:
|
||||||
|
"""Create and persist a new refresh token record."""
|
||||||
|
token = OAuthProviderRefreshToken(
|
||||||
|
token_hash=token_hash,
|
||||||
|
jti=jti,
|
||||||
|
client_id=client_id,
|
||||||
|
user_id=user_id,
|
||||||
|
scope=scope,
|
||||||
|
expires_at=expires_at,
|
||||||
|
device_info=device_info,
|
||||||
|
ip_address=ip_address,
|
||||||
|
)
|
||||||
|
db.add(token)
|
||||||
|
await db.commit()
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def get_by_token_hash(
|
||||||
|
self, db: AsyncSession, *, token_hash: str
|
||||||
|
) -> OAuthProviderRefreshToken | None:
|
||||||
|
"""Get refresh token record by SHA-256 token hash."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthProviderRefreshToken).where(
|
||||||
|
OAuthProviderRefreshToken.token_hash == token_hash
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_jti(
|
||||||
|
self, db: AsyncSession, *, jti: str
|
||||||
|
) -> OAuthProviderRefreshToken | None:
|
||||||
|
"""Get refresh token record by JWT ID (JTI)."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthProviderRefreshToken).where(
|
||||||
|
OAuthProviderRefreshToken.jti == jti
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def revoke(
|
||||||
|
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
|
||||||
|
) -> None:
|
||||||
|
"""Mark a specific token record as revoked."""
|
||||||
|
token.revoked = True # type: ignore[assignment]
|
||||||
|
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def revoke_all_for_user_client(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Revoke all active tokens for a specific user-client pair.
|
||||||
|
|
||||||
|
Used when security incidents are detected (e.g., authorization code reuse).
|
||||||
|
Returns the number of tokens revoked.
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
update(OAuthProviderRefreshToken)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthProviderRefreshToken.user_id == user_id,
|
||||||
|
OAuthProviderRefreshToken.client_id == client_id,
|
||||||
|
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(revoked=True)
|
||||||
|
)
|
||||||
|
count = result.rowcount # type: ignore[attr-defined]
|
||||||
|
if count > 0:
|
||||||
|
await db.commit()
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def revoke_all_for_user(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Revoke all active tokens for a user across all clients.
|
||||||
|
|
||||||
|
Used when user changes password or logs out everywhere.
|
||||||
|
Returns the number of tokens revoked.
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
update(OAuthProviderRefreshToken)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthProviderRefreshToken.user_id == user_id,
|
||||||
|
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(revoked=True)
|
||||||
|
)
|
||||||
|
count = result.rowcount # type: ignore[attr-defined]
|
||||||
|
if count > 0:
|
||||||
|
await db.commit()
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def cleanup_expired(
|
||||||
|
self, db: AsyncSession, *, cutoff_days: int = 7
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Delete expired refresh tokens older than cutoff_days.
|
||||||
|
|
||||||
|
Should be called periodically (e.g., daily).
|
||||||
|
Returns the number of tokens deleted.
|
||||||
|
"""
|
||||||
|
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthProviderRefreshToken).where(
|
||||||
|
OAuthProviderRefreshToken.expires_at < cutoff
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return result.rowcount # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_provider_token_repo = OAuthProviderTokenRepository()
|
||||||
113
backend/app/repositories/oauth_state.py
Normal file
113
backend/app/repositories/oauth_state.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
# app/repositories/oauth_state.py
|
||||||
|
"""Repository for OAuthState model async CRUD operations."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
|
from app.models.oauth_state import OAuthState
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
from app.schemas.oauth import OAuthStateCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptySchema(BaseModel):
|
||||||
|
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
|
||||||
|
"""Repository for OAuth state (CSRF protection)."""
|
||||||
|
|
||||||
|
async def create_state(
|
||||||
|
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
||||||
|
) -> OAuthState:
|
||||||
|
"""Create a new OAuth state for CSRF protection."""
|
||||||
|
try:
|
||||||
|
db_obj = OAuthState(
|
||||||
|
state=obj_in.state,
|
||||||
|
code_verifier=obj_in.code_verifier,
|
||||||
|
nonce=obj_in.nonce,
|
||||||
|
provider=obj_in.provider,
|
||||||
|
redirect_uri=obj_in.redirect_uri,
|
||||||
|
user_id=obj_in.user_id,
|
||||||
|
expires_at=obj_in.expires_at,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
|
||||||
|
logger.debug(f"OAuth state created for {obj_in.provider}")
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
logger.error(f"OAuth state collision: {error_msg}")
|
||||||
|
raise DuplicateEntryError("Failed to create OAuth state, please retry")
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_and_consume_state(
|
||||||
|
self, db: AsyncSession, *, state: str
|
||||||
|
) -> OAuthState | None:
|
||||||
|
"""Get and delete OAuth state (consume it)."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthState).where(OAuthState.state == state)
|
||||||
|
)
|
||||||
|
db_obj = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if db_obj is None:
|
||||||
|
logger.warning(f"OAuth state not found: {state[:8]}...")
|
||||||
|
return None
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expires_at = db_obj.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
if expires_at < now:
|
||||||
|
logger.warning(f"OAuth state expired: {state[:8]}...")
|
||||||
|
await db.delete(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
await db.delete(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
logger.debug(f"OAuth state consumed: {state[:8]}...")
|
||||||
|
return db_obj
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error consuming OAuth state: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||||
|
"""Clean up expired OAuth states."""
|
||||||
|
try:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
count = result.rowcount
|
||||||
|
if count > 0:
|
||||||
|
logger.info(f"Cleaned up {count} expired OAuth states")
|
||||||
|
|
||||||
|
return count
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_state_repo = OAuthStateRepository(OAuthState)
|
||||||
73
backend/app/crud/organization.py → backend/app/repositories/organization.py
Executable file → Normal file
73
backend/app/crud/organization.py → backend/app/repositories/organization.py
Executable file → Normal file
@@ -1,5 +1,5 @@
|
|||||||
# app/crud/organization_async.py
|
# app/repositories/organization.py
|
||||||
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
"""Repository for Organization model async CRUD operations using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -9,10 +9,11 @@ from sqlalchemy import and_, case, func, or_, select
|
|||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.crud.base import CRUDBase
|
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
from app.schemas.organizations import (
|
from app.schemas.organizations import (
|
||||||
OrganizationCreate,
|
OrganizationCreate,
|
||||||
OrganizationUpdate,
|
OrganizationUpdate,
|
||||||
@@ -21,8 +22,8 @@ from app.schemas.organizations import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||||
"""Async CRUD operations for Organization model."""
|
"""Repository for Organization model."""
|
||||||
|
|
||||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
||||||
"""Get organization by slug."""
|
"""Get organization by slug."""
|
||||||
@@ -54,13 +55,13 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
if "slug" in error_msg.lower():
|
if "slug" in error_msg.lower() or "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||||
raise ValueError(
|
raise DuplicateEntryError(
|
||||||
f"Organization with slug '{obj_in.slug}' already exists"
|
f"Organization with slug '{obj_in.slug}' already exists"
|
||||||
)
|
)
|
||||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -79,16 +80,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
sort_by: str = "created_at",
|
sort_by: str = "created_at",
|
||||||
sort_order: str = "desc",
|
sort_order: str = "desc",
|
||||||
) -> tuple[list[Organization], int]:
|
) -> tuple[list[Organization], int]:
|
||||||
"""
|
"""Get multiple organizations with filtering, searching, and sorting."""
|
||||||
Get multiple organizations with filtering, searching, and sorting.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (organizations list, total count)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
query = select(Organization)
|
query = select(Organization)
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.where(Organization.is_active == is_active)
|
query = query.where(Organization.is_active == is_active)
|
||||||
|
|
||||||
@@ -100,19 +95,16 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
)
|
)
|
||||||
query = query.where(search_filter)
|
query = query.where(search_filter)
|
||||||
|
|
||||||
# Get total count before pagination
|
|
||||||
count_query = select(func.count()).select_from(query.alias())
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply sorting
|
|
||||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||||
if sort_order == "desc":
|
if sort_order == "desc":
|
||||||
query = query.order_by(sort_column.desc())
|
query = query.order_by(sort_column.desc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(sort_column.asc())
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
query = query.offset(skip).limit(limit)
|
query = query.offset(skip).limit(limit)
|
||||||
result = await db.execute(query)
|
result = await db.execute(query)
|
||||||
organizations = list(result.scalars().all())
|
organizations = list(result.scalars().all())
|
||||||
@@ -149,16 +141,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
is_active: bool | None = None,
|
is_active: bool | None = None,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
) -> tuple[list[dict[str, Any]], int]:
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
"""
|
"""Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
|
||||||
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
|
||||||
This eliminates the N+1 query problem.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (list of dicts with org and member_count, total count)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Build base query with LEFT JOIN and GROUP BY
|
|
||||||
# Use CASE statement to count only active members
|
|
||||||
query = (
|
query = (
|
||||||
select(
|
select(
|
||||||
Organization,
|
Organization,
|
||||||
@@ -181,7 +165,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
.group_by(Organization.id)
|
.group_by(Organization.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.where(Organization.is_active == is_active)
|
query = query.where(Organization.is_active == is_active)
|
||||||
|
|
||||||
@@ -193,7 +176,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
)
|
)
|
||||||
query = query.where(search_filter)
|
query = query.where(search_filter)
|
||||||
|
|
||||||
# Get total count
|
|
||||||
count_query = select(func.count(Organization.id))
|
count_query = select(func.count(Organization.id))
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
count_query = count_query.where(Organization.is_active == is_active)
|
count_query = count_query.where(Organization.is_active == is_active)
|
||||||
@@ -203,7 +185,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply pagination and ordering
|
|
||||||
query = (
|
query = (
|
||||||
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||||
)
|
)
|
||||||
@@ -211,7 +192,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
result = await db.execute(query)
|
result = await db.execute(query)
|
||||||
rows = result.all()
|
rows = result.all()
|
||||||
|
|
||||||
# Convert to list of dicts
|
|
||||||
orgs_with_counts = [
|
orgs_with_counts = [
|
||||||
{"organization": org, "member_count": member_count}
|
{"organization": org, "member_count": member_count}
|
||||||
for org, member_count in rows
|
for org, member_count in rows
|
||||||
@@ -236,7 +216,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
) -> UserOrganization:
|
) -> UserOrganization:
|
||||||
"""Add a user to an organization with a specific role."""
|
"""Add a user to an organization with a specific role."""
|
||||||
try:
|
try:
|
||||||
# Check if relationship already exists
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(UserOrganization).where(
|
select(UserOrganization).where(
|
||||||
and_(
|
and_(
|
||||||
@@ -248,7 +227,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
existing = result.scalar_one_or_none()
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
# Reactivate if inactive, or raise error if already active
|
|
||||||
if not existing.is_active:
|
if not existing.is_active:
|
||||||
existing.is_active = True
|
existing.is_active = True
|
||||||
existing.role = role
|
existing.role = role
|
||||||
@@ -257,9 +235,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
await db.refresh(existing)
|
await db.refresh(existing)
|
||||||
return existing
|
return existing
|
||||||
else:
|
else:
|
||||||
raise ValueError("User is already a member of this organization")
|
raise DuplicateEntryError("User is already a member of this organization")
|
||||||
|
|
||||||
# Create new relationship
|
|
||||||
user_org = UserOrganization(
|
user_org = UserOrganization(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
@@ -274,7 +251,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Integrity error adding user to organization: {e!s}")
|
logger.error(f"Integrity error adding user to organization: {e!s}")
|
||||||
raise ValueError("Failed to add user to organization")
|
raise IntegrityConstraintError("Failed to add user to organization")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
|
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
|
||||||
@@ -350,14 +327,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
is_active: bool = True,
|
is_active: bool = True,
|
||||||
) -> tuple[list[dict[str, Any]], int]:
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
"""
|
"""Get members of an organization with user details."""
|
||||||
Get members of an organization with user details.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (members list with user details, total count)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Build query with join
|
|
||||||
query = (
|
query = (
|
||||||
select(UserOrganization, User)
|
select(UserOrganization, User)
|
||||||
.join(User, UserOrganization.user_id == User.id)
|
.join(User, UserOrganization.user_id == User.id)
|
||||||
@@ -367,7 +338,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.where(UserOrganization.is_active == is_active)
|
query = query.where(UserOrganization.is_active == is_active)
|
||||||
|
|
||||||
# Get total count
|
|
||||||
count_query = select(func.count()).select_from(
|
count_query = select(func.count()).select_from(
|
||||||
select(UserOrganization)
|
select(UserOrganization)
|
||||||
.where(UserOrganization.organization_id == organization_id)
|
.where(UserOrganization.organization_id == organization_id)
|
||||||
@@ -381,7 +351,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply ordering and pagination
|
|
||||||
query = (
|
query = (
|
||||||
query.order_by(UserOrganization.created_at.desc())
|
query.order_by(UserOrganization.created_at.desc())
|
||||||
.offset(skip)
|
.offset(skip)
|
||||||
@@ -435,15 +404,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
async def get_user_organizations_with_details(
|
async def get_user_organizations_with_details(
|
||||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""Get user's organizations with role and member count in SINGLE QUERY."""
|
||||||
Get user's organizations with role and member count in SINGLE QUERY.
|
|
||||||
Eliminates N+1 problem by using subquery for member counts.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts with organization, role, and member_count
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Subquery to get member counts for each organization
|
|
||||||
member_count_subq = (
|
member_count_subq = (
|
||||||
select(
|
select(
|
||||||
UserOrganization.organization_id,
|
UserOrganization.organization_id,
|
||||||
@@ -454,7 +416,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Main query with JOIN to get org, role, and member count
|
|
||||||
query = (
|
query = (
|
||||||
select(
|
select(
|
||||||
Organization,
|
Organization,
|
||||||
@@ -531,5 +492,5 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance for use across the application
|
# Singleton instance
|
||||||
organization = CRUDOrganization(Organization)
|
organization_repo = OrganizationRepository(Organization)
|
||||||
183
backend/app/crud/session.py → backend/app/repositories/session.py
Executable file → Normal file
183
backend/app/crud/session.py → backend/app/repositories/session.py
Executable file → Normal file
@@ -1,6 +1,5 @@
|
|||||||
"""
|
# app/repositories/session.py
|
||||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
"""Repository for UserSession model async CRUD operations using SQLAlchemy 2.0 patterns."""
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
@@ -11,27 +10,19 @@ from sqlalchemy import and_, delete, func, select, update
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from app.crud.base import CRUDBase
|
from app.core.repository_exceptions import InvalidInputError, IntegrityConstraintError
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
|
||||||
"""Async CRUD operations for user sessions."""
|
"""Repository for UserSession model."""
|
||||||
|
|
||||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||||
"""
|
"""Get session by refresh token JTI."""
|
||||||
Get session by refresh token JTI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
jti: Refresh token JWT ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||||
@@ -44,16 +35,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
async def get_active_by_jti(
|
async def get_active_by_jti(
|
||||||
self, db: AsyncSession, *, jti: str
|
self, db: AsyncSession, *, jti: str
|
||||||
) -> UserSession | None:
|
) -> UserSession | None:
|
||||||
"""
|
"""Get active session by refresh token JTI."""
|
||||||
Get active session by refresh token JTI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
jti: Refresh token JWT ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Active UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(UserSession).where(
|
select(UserSession).where(
|
||||||
@@ -76,25 +58,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
active_only: bool = True,
|
active_only: bool = True,
|
||||||
with_user: bool = False,
|
with_user: bool = False,
|
||||||
) -> list[UserSession]:
|
) -> list[UserSession]:
|
||||||
"""
|
"""Get all sessions for a user with optional eager loading."""
|
||||||
Get all sessions for a user with optional eager loading.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
active_only: If True, return only active sessions
|
|
||||||
with_user: If True, eager load user relationship to prevent N+1
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of UserSession objects
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||||
|
|
||||||
# Add eager loading if requested to prevent N+1 queries
|
|
||||||
if with_user:
|
if with_user:
|
||||||
query = query.options(joinedload(UserSession.user))
|
query = query.options(joinedload(UserSession.user))
|
||||||
|
|
||||||
@@ -111,19 +80,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
async def create_session(
|
async def create_session(
|
||||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||||
) -> UserSession:
|
) -> UserSession:
|
||||||
"""
|
"""Create a new user session."""
|
||||||
Create a new user session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
obj_in: SessionCreate schema with session data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created UserSession
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If session creation fails
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
db_obj = UserSession(
|
db_obj = UserSession(
|
||||||
user_id=obj_in.user_id,
|
user_id=obj_in.user_id,
|
||||||
@@ -151,21 +108,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error creating session: {e!s}", exc_info=True)
|
logger.error(f"Error creating session: {e!s}", exc_info=True)
|
||||||
raise ValueError(f"Failed to create session: {e!s}")
|
raise IntegrityConstraintError(f"Failed to create session: {e!s}")
|
||||||
|
|
||||||
async def deactivate(
|
async def deactivate(
|
||||||
self, db: AsyncSession, *, session_id: str
|
self, db: AsyncSession, *, session_id: str
|
||||||
) -> UserSession | None:
|
) -> UserSession | None:
|
||||||
"""
|
"""Deactivate a session (logout from device)."""
|
||||||
Deactivate a session (logout from device).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session_id: Session UUID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deactivated UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
session = await self.get(db, id=session_id)
|
session = await self.get(db, id=session_id)
|
||||||
if not session:
|
if not session:
|
||||||
@@ -191,18 +139,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
async def deactivate_all_user_sessions(
|
async def deactivate_all_user_sessions(
|
||||||
self, db: AsyncSession, *, user_id: str
|
self, db: AsyncSession, *, user_id: str
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""Deactivate all active sessions for a user (logout from all devices)."""
|
||||||
Deactivate all active sessions for a user (logout from all devices).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deactivated
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
@@ -227,16 +165,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
async def update_last_used(
|
async def update_last_used(
|
||||||
self, db: AsyncSession, *, session: UserSession
|
self, db: AsyncSession, *, session: UserSession
|
||||||
) -> UserSession:
|
) -> UserSession:
|
||||||
"""
|
"""Update the last_used_at timestamp for a session."""
|
||||||
Update the last_used_at timestamp for a session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session: UserSession object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated UserSession
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
session.last_used_at = datetime.now(UTC)
|
session.last_used_at = datetime.now(UTC)
|
||||||
db.add(session)
|
db.add(session)
|
||||||
@@ -256,20 +185,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
new_jti: str,
|
new_jti: str,
|
||||||
new_expires_at: datetime,
|
new_expires_at: datetime,
|
||||||
) -> UserSession:
|
) -> UserSession:
|
||||||
"""
|
"""Update session with new refresh token JTI and expiration."""
|
||||||
Update session with new refresh token JTI and expiration.
|
|
||||||
|
|
||||||
Called during token refresh.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session: UserSession object
|
|
||||||
new_jti: New refresh token JTI
|
|
||||||
new_expires_at: New expiration datetime
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated UserSession
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
session.refresh_token_jti = new_jti
|
session.refresh_token_jti = new_jti
|
||||||
session.expires_at = new_expires_at
|
session.expires_at = new_expires_at
|
||||||
@@ -286,27 +202,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||||
"""
|
"""Clean up expired sessions using optimized bulk DELETE."""
|
||||||
Clean up expired sessions using optimized bulk DELETE.
|
|
||||||
|
|
||||||
Deletes sessions that are:
|
|
||||||
- Expired AND inactive
|
|
||||||
- Older than keep_days
|
|
||||||
|
|
||||||
Uses single DELETE query instead of N individual deletes for efficiency.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
keep_days: Keep inactive sessions for this many days (for audit)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deleted
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
# Use bulk DELETE with WHERE clause - single query
|
|
||||||
stmt = delete(UserSession).where(
|
stmt = delete(UserSession).where(
|
||||||
and_(
|
and_(
|
||||||
UserSession.is_active == False, # noqa: E712
|
UserSession.is_active == False, # noqa: E712
|
||||||
@@ -330,29 +230,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||||
"""
|
"""Clean up expired and inactive sessions for a specific user."""
|
||||||
Clean up expired and inactive sessions for a specific user.
|
|
||||||
|
|
||||||
Uses single bulk DELETE query for efficiency instead of N individual deletes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID to cleanup sessions for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deleted
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Validate UUID
|
|
||||||
try:
|
try:
|
||||||
uuid_obj = uuid.UUID(user_id)
|
uuid_obj = uuid.UUID(user_id)
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
logger.error(f"Invalid UUID format: {user_id}")
|
logger.error(f"Invalid UUID format: {user_id}")
|
||||||
raise ValueError(f"Invalid user ID format: {user_id}")
|
raise InvalidInputError(f"Invalid user ID format: {user_id}")
|
||||||
|
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
# Use bulk DELETE with WHERE clause - single query
|
|
||||||
stmt = delete(UserSession).where(
|
stmt = delete(UserSession).where(
|
||||||
and_(
|
and_(
|
||||||
UserSession.user_id == uuid_obj,
|
UserSession.user_id == uuid_obj,
|
||||||
@@ -380,18 +267,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||||
"""
|
"""Get count of active sessions for a user."""
|
||||||
Get count of active sessions for a user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of active sessions
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -413,31 +290,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
active_only: bool = True,
|
active_only: bool = True,
|
||||||
with_user: bool = True,
|
with_user: bool = True,
|
||||||
) -> tuple[list[UserSession], int]:
|
) -> tuple[list[UserSession], int]:
|
||||||
"""
|
"""Get all sessions across all users with pagination (admin only)."""
|
||||||
Get all sessions across all users with pagination (admin only).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
active_only: If True, return only active sessions
|
|
||||||
with_user: If True, eager load user relationship to prevent N+1
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (list of UserSession objects, total count)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Build query
|
|
||||||
query = select(UserSession)
|
query = select(UserSession)
|
||||||
|
|
||||||
# Add eager loading if requested to prevent N+1 queries
|
|
||||||
if with_user:
|
if with_user:
|
||||||
query = query.options(joinedload(UserSession.user))
|
query = query.options(joinedload(UserSession.user))
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
query = query.where(UserSession.is_active)
|
query = query.where(UserSession.is_active)
|
||||||
|
|
||||||
# Get total count
|
|
||||||
count_query = select(func.count(UserSession.id))
|
count_query = select(func.count(UserSession.id))
|
||||||
if active_only:
|
if active_only:
|
||||||
count_query = count_query.where(UserSession.is_active)
|
count_query = count_query.where(UserSession.is_active)
|
||||||
@@ -445,7 +307,6 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply pagination and ordering
|
|
||||||
query = (
|
query = (
|
||||||
query.order_by(UserSession.last_used_at.desc())
|
query.order_by(UserSession.last_used_at.desc())
|
||||||
.offset(skip)
|
.offset(skip)
|
||||||
@@ -462,5 +323,5 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# Create singleton instance
|
# Singleton instance
|
||||||
session = CRUDSession(UserSession)
|
session_repo = SessionRepository(UserSession)
|
||||||
129
backend/app/crud/user.py → backend/app/repositories/user.py
Executable file → Normal file
129
backend/app/crud/user.py → backend/app/repositories/user.py
Executable file → Normal file
@@ -1,5 +1,5 @@
|
|||||||
# app/crud/user_async.py
|
# app/repositories/user.py
|
||||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
"""Repository for User model async CRUD operations using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
@@ -11,15 +11,16 @@ from sqlalchemy.exc import IntegrityError
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import get_password_hash_async
|
from app.core.auth import get_password_hash_async
|
||||||
from app.crud.base import CRUDBase
|
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
from app.schemas.users import UserCreate, UserUpdate
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||||
"""Async CRUD operations for User model."""
|
"""Repository for User model."""
|
||||||
|
|
||||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
||||||
"""Get user by email address."""
|
"""Get user by email address."""
|
||||||
@@ -33,7 +34,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||||
"""Create a new user with async password hashing and error handling."""
|
"""Create a new user with async password hashing and error handling."""
|
||||||
try:
|
try:
|
||||||
# Hash password asynchronously to avoid blocking event loop
|
|
||||||
password_hash = await get_password_hash_async(obj_in.password)
|
password_hash = await get_password_hash_async(obj_in.password)
|
||||||
|
|
||||||
db_obj = User(
|
db_obj = User(
|
||||||
@@ -58,14 +58,48 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
if "email" in error_msg.lower():
|
if "email" in error_msg.lower():
|
||||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
raise DuplicateEntryError(f"User with email {obj_in.email} already exists")
|
||||||
logger.error(f"Integrity error creating user: {error_msg}")
|
logger.error(f"Integrity error creating user: {error_msg}")
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
|
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def create_oauth_user(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
email: str,
|
||||||
|
first_name: str = "User",
|
||||||
|
last_name: str | None = None,
|
||||||
|
) -> User:
|
||||||
|
"""Create a new passwordless user for OAuth sign-in."""
|
||||||
|
try:
|
||||||
|
db_obj = User(
|
||||||
|
email=email,
|
||||||
|
password_hash=None, # OAuth-only user
|
||||||
|
first_name=first_name,
|
||||||
|
last_name=last_name,
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.flush() # Get user.id without committing
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
if "email" in error_msg.lower():
|
||||||
|
logger.warning(f"Duplicate email attempted: {email}")
|
||||||
|
raise DuplicateEntryError(f"User with email {email} already exists")
|
||||||
|
logger.error(f"Integrity error creating OAuth user: {error_msg}")
|
||||||
|
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Unexpected error creating OAuth user: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
async def update(
|
async def update(
|
||||||
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
|
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
|
||||||
) -> User:
|
) -> User:
|
||||||
@@ -75,8 +109,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
else:
|
else:
|
||||||
update_data = obj_in.model_dump(exclude_unset=True)
|
update_data = obj_in.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
# Handle password separately if it exists in update data
|
|
||||||
# Hash password asynchronously to avoid blocking event loop
|
|
||||||
if "password" in update_data:
|
if "password" in update_data:
|
||||||
update_data["password_hash"] = await get_password_hash_async(
|
update_data["password_hash"] = await get_password_hash_async(
|
||||||
update_data["password"]
|
update_data["password"]
|
||||||
@@ -85,6 +117,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
|
|
||||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||||
|
|
||||||
|
async def update_password(
|
||||||
|
self, db: AsyncSession, *, user: User, password_hash: str
|
||||||
|
) -> User:
|
||||||
|
"""Set a new password hash on a user and commit."""
|
||||||
|
user.password_hash = password_hash
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
async def get_multi_with_total(
|
async def get_multi_with_total(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
@@ -96,43 +137,23 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
filters: dict[str, Any] | None = None,
|
filters: dict[str, Any] | None = None,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
) -> tuple[list[User], int]:
|
) -> tuple[list[User], int]:
|
||||||
"""
|
"""Get multiple users with total count, filtering, sorting, and search."""
|
||||||
Get multiple users with total count, filtering, sorting, and search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
sort_by: Field name to sort by
|
|
||||||
sort_order: Sort order ("asc" or "desc")
|
|
||||||
filters: Dictionary of filters (field_name: value)
|
|
||||||
search: Search term to match against email, first_name, last_name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (users list, total count)
|
|
||||||
"""
|
|
||||||
# Validate pagination
|
|
||||||
if skip < 0:
|
if skip < 0:
|
||||||
raise ValueError("skip must be non-negative")
|
raise InvalidInputError("skip must be non-negative")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
raise ValueError("limit must be non-negative")
|
raise InvalidInputError("limit must be non-negative")
|
||||||
if limit > 1000:
|
if limit > 1000:
|
||||||
raise ValueError("Maximum limit is 1000")
|
raise InvalidInputError("Maximum limit is 1000")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build base query
|
|
||||||
query = select(User)
|
query = select(User)
|
||||||
|
|
||||||
# Exclude soft-deleted users
|
|
||||||
query = query.where(User.deleted_at.is_(None))
|
query = query.where(User.deleted_at.is_(None))
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if filters:
|
if filters:
|
||||||
for field, value in filters.items():
|
for field, value in filters.items():
|
||||||
if hasattr(User, field) and value is not None:
|
if hasattr(User, field) and value is not None:
|
||||||
query = query.where(getattr(User, field) == value)
|
query = query.where(getattr(User, field) == value)
|
||||||
|
|
||||||
# Apply search
|
|
||||||
if search:
|
if search:
|
||||||
search_filter = or_(
|
search_filter = or_(
|
||||||
User.email.ilike(f"%{search}%"),
|
User.email.ilike(f"%{search}%"),
|
||||||
@@ -141,14 +162,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
)
|
)
|
||||||
query = query.where(search_filter)
|
query = query.where(search_filter)
|
||||||
|
|
||||||
# Get total count
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
count_query = select(func.count()).select_from(query.alias())
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply sorting
|
|
||||||
if sort_by and hasattr(User, sort_by):
|
if sort_by and hasattr(User, sort_by):
|
||||||
sort_column = getattr(User, sort_by)
|
sort_column = getattr(User, sort_by)
|
||||||
if sort_order.lower() == "desc":
|
if sort_order.lower() == "desc":
|
||||||
@@ -156,7 +175,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
else:
|
else:
|
||||||
query = query.order_by(sort_column.asc())
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
query = query.offset(skip).limit(limit)
|
query = query.offset(skip).limit(limit)
|
||||||
result = await db.execute(query)
|
result = await db.execute(query)
|
||||||
users = list(result.scalars().all())
|
users = list(result.scalars().all())
|
||||||
@@ -170,26 +188,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
async def bulk_update_status(
|
async def bulk_update_status(
|
||||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""Bulk update is_active status for multiple users."""
|
||||||
Bulk update is_active status for multiple users.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_ids: List of user IDs to update
|
|
||||||
is_active: New active status
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of users updated
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Use UPDATE with WHERE IN for efficiency
|
|
||||||
stmt = (
|
stmt = (
|
||||||
update(User)
|
update(User)
|
||||||
.where(User.id.in_(user_ids))
|
.where(User.id.in_(user_ids))
|
||||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
.where(User.deleted_at.is_(None))
|
||||||
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -212,34 +219,20 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
user_ids: list[UUID],
|
user_ids: list[UUID],
|
||||||
exclude_user_id: UUID | None = None,
|
exclude_user_id: UUID | None = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""Bulk soft delete multiple users."""
|
||||||
Bulk soft delete multiple users.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_ids: List of user IDs to delete
|
|
||||||
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of users deleted
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Remove excluded user from list
|
|
||||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||||
|
|
||||||
if not filtered_ids:
|
if not filtered_ids:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Use UPDATE with WHERE IN for efficiency
|
|
||||||
stmt = (
|
stmt = (
|
||||||
update(User)
|
update(User)
|
||||||
.where(User.id.in_(filtered_ids))
|
.where(User.id.in_(filtered_ids))
|
||||||
.where(
|
.where(User.deleted_at.is_(None))
|
||||||
User.deleted_at.is_(None)
|
|
||||||
) # Don't re-delete already deleted users
|
|
||||||
.values(
|
.values(
|
||||||
deleted_at=datetime.now(UTC),
|
deleted_at=datetime.now(UTC),
|
||||||
is_active=False,
|
is_active=False,
|
||||||
@@ -268,5 +261,5 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
return user.is_superuser
|
return user.is_superuser
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance for use across the application
|
# Singleton instance
|
||||||
user = CRUDUser(User)
|
user_repo = UserRepository(User)
|
||||||
@@ -60,8 +60,8 @@ class OAuthAccountCreate(OAuthAccountBase):
|
|||||||
|
|
||||||
user_id: UUID
|
user_id: UUID
|
||||||
provider_user_id: str = Field(..., max_length=255)
|
provider_user_id: str = Field(..., max_length=255)
|
||||||
access_token_encrypted: str | None = None
|
access_token: str | None = None
|
||||||
refresh_token_encrypted: str | None = None
|
refresh_token: str | None = None
|
||||||
token_expires_at: datetime | None = None
|
token_expires_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,19 @@
|
|||||||
# app/services/__init__.py
|
# app/services/__init__.py
|
||||||
|
from . import oauth_provider_service
|
||||||
from .auth_service import AuthService
|
from .auth_service import AuthService
|
||||||
from .oauth_service import OAuthService
|
from .oauth_service import OAuthService
|
||||||
|
from .organization_service import OrganizationService, organization_service
|
||||||
|
from .session_service import SessionService, session_service
|
||||||
|
from .user_service import UserService, user_service
|
||||||
|
|
||||||
__all__ = ["AuthService", "OAuthService"]
|
__all__ = [
|
||||||
|
"AuthService",
|
||||||
|
"OAuthService",
|
||||||
|
"UserService",
|
||||||
|
"OrganizationService",
|
||||||
|
"SessionService",
|
||||||
|
"oauth_provider_service",
|
||||||
|
"user_service",
|
||||||
|
"organization_service",
|
||||||
|
"session_service",
|
||||||
|
]
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import (
|
from app.core.auth import (
|
||||||
@@ -14,12 +13,18 @@ from app.core.auth import (
|
|||||||
verify_password_async,
|
verify_password_async,
|
||||||
)
|
)
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.exceptions import AuthenticationError
|
from app.core.exceptions import AuthenticationError, DuplicateError
|
||||||
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.user import user_repo
|
||||||
from app.schemas.users import Token, UserCreate, UserResponse
|
from app.schemas.users import Token, UserCreate, UserResponse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
|
||||||
|
# preventing timing attacks that could enumerate valid email addresses.
|
||||||
|
_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"
|
||||||
|
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
"""Service for handling authentication operations"""
|
"""Service for handling authentication operations"""
|
||||||
@@ -39,10 +44,12 @@ class AuthService:
|
|||||||
Returns:
|
Returns:
|
||||||
User if authenticated, None otherwise
|
User if authenticated, None otherwise
|
||||||
"""
|
"""
|
||||||
result = await db.execute(select(User).where(User.email == email))
|
user = await user_repo.get_by_email(db, email=email)
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
|
# Perform a dummy verification to match timing of a real bcrypt check,
|
||||||
|
# preventing email enumeration via response-time differences.
|
||||||
|
await verify_password_async(password, _DUMMY_HASH)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Verify password asynchronously to avoid blocking event loop
|
# Verify password asynchronously to avoid blocking event loop
|
||||||
@@ -71,39 +78,22 @@ class AuthService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Check if user already exists
|
# Check if user already exists
|
||||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
existing_user = await user_repo.get_by_email(db, email=user_data.email)
|
||||||
existing_user = result.scalar_one_or_none()
|
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise AuthenticationError("User with this email already exists")
|
raise DuplicateError("User with this email already exists")
|
||||||
|
|
||||||
# Create new user with async password hashing
|
# Delegate creation (hashing + commit) to the repository
|
||||||
# Hash password asynchronously to avoid blocking event loop
|
user = await user_repo.create(db, obj_in=user_data)
|
||||||
hashed_password = await get_password_hash_async(user_data.password)
|
|
||||||
|
|
||||||
# Create user object from model
|
|
||||||
user = User(
|
|
||||||
email=user_data.email,
|
|
||||||
password_hash=hashed_password,
|
|
||||||
first_name=user_data.first_name,
|
|
||||||
last_name=user_data.last_name,
|
|
||||||
phone_number=user_data.phone_number,
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(user)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user)
|
|
||||||
|
|
||||||
logger.info(f"User created successfully: {user.email}")
|
logger.info(f"User created successfully: {user.email}")
|
||||||
return user
|
return user
|
||||||
|
|
||||||
except AuthenticationError:
|
except (AuthenticationError, DuplicateError):
|
||||||
# Re-raise authentication errors without rollback
|
# Re-raise API exceptions without rollback
|
||||||
raise
|
raise
|
||||||
|
except DuplicateEntryError as e:
|
||||||
|
raise DuplicateError(str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Rollback on any database errors
|
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error creating user: {e!s}", exc_info=True)
|
logger.error(f"Error creating user: {e!s}", exc_info=True)
|
||||||
raise AuthenticationError(f"Failed to create user: {e!s}")
|
raise AuthenticationError(f"Failed to create user: {e!s}")
|
||||||
|
|
||||||
@@ -168,8 +158,7 @@ class AuthService:
|
|||||||
user_id = token_data.user_id
|
user_id = token_data.user_id
|
||||||
|
|
||||||
# Get user from database
|
# Get user from database
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
user = await user_repo.get(db, id=str(user_id))
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
raise TokenInvalidError("Invalid user or inactive account")
|
raise TokenInvalidError("Invalid user or inactive account")
|
||||||
|
|
||||||
@@ -200,8 +189,7 @@ class AuthService:
|
|||||||
AuthenticationError: If current password is incorrect or update fails
|
AuthenticationError: If current password is incorrect or update fails
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
user = await user_repo.get(db, id=str(user_id))
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if not user:
|
if not user:
|
||||||
raise AuthenticationError("User not found")
|
raise AuthenticationError("User not found")
|
||||||
|
|
||||||
@@ -210,8 +198,8 @@ class AuthService:
|
|||||||
raise AuthenticationError("Current password is incorrect")
|
raise AuthenticationError("Current password is incorrect")
|
||||||
|
|
||||||
# Hash new password asynchronously to avoid blocking event loop
|
# Hash new password asynchronously to avoid blocking event loop
|
||||||
user.password_hash = await get_password_hash_async(new_password)
|
new_hash = await get_password_hash_async(new_password)
|
||||||
await db.commit()
|
await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||||
|
|
||||||
logger.info(f"Password changed successfully for user {user_id}")
|
logger.info(f"Password changed successfully for user {user_id}")
|
||||||
return True
|
return True
|
||||||
@@ -226,3 +214,32 @@ class AuthService:
|
|||||||
f"Error changing password for user {user_id}: {e!s}", exc_info=True
|
f"Error changing password for user {user_id}: {e!s}", exc_info=True
|
||||||
)
|
)
|
||||||
raise AuthenticationError(f"Failed to change password: {e!s}")
|
raise AuthenticationError(f"Failed to change password: {e!s}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def reset_password(
|
||||||
|
db: AsyncSession, *, email: str, new_password: str
|
||||||
|
) -> User:
|
||||||
|
"""
|
||||||
|
Reset a user's password without requiring the current password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
email: User email address
|
||||||
|
new_password: New password to set
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If user not found or inactive
|
||||||
|
"""
|
||||||
|
user = await user_repo.get_by_email(db, email=email)
|
||||||
|
if not user:
|
||||||
|
raise AuthenticationError("User not found")
|
||||||
|
if not user.is_active:
|
||||||
|
raise AuthenticationError("User account is inactive")
|
||||||
|
|
||||||
|
new_hash = await get_password_hash_async(new_password)
|
||||||
|
user = await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||||
|
logger.info(f"Password reset successfully for {email}")
|
||||||
|
return user
|
||||||
|
|||||||
@@ -26,14 +26,17 @@ from typing import Any
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from sqlalchemy import and_, delete, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
|
||||||
from app.models.oauth_client import OAuthClient
|
from app.models.oauth_client import OAuthClient
|
||||||
from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
from app.schemas.oauth import OAuthClientCreate
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
|
||||||
|
from app.repositories.oauth_client import oauth_client_repo
|
||||||
|
from app.repositories.oauth_consent import oauth_consent_repo
|
||||||
|
from app.repositories.oauth_provider_token import oauth_provider_token_repo
|
||||||
|
from app.repositories.user import user_repo
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -161,15 +164,7 @@ def join_scope(scopes: list[str]) -> str:
|
|||||||
|
|
||||||
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
|
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
|
||||||
"""Get OAuth client by client_id."""
|
"""Get OAuth client by client_id."""
|
||||||
result = await db.execute(
|
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
|
||||||
select(OAuthClient).where(
|
|
||||||
and_(
|
|
||||||
OAuthClient.client_id == client_id,
|
|
||||||
OAuthClient.is_active == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_client(
|
async def validate_client(
|
||||||
@@ -204,21 +199,19 @@ async def validate_client(
|
|||||||
if not client.client_secret_hash:
|
if not client.client_secret_hash:
|
||||||
raise InvalidClientError("Client not configured with secret")
|
raise InvalidClientError("Client not configured with secret")
|
||||||
|
|
||||||
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
# SECURITY: Verify secret using bcrypt
|
||||||
# Supports both bcrypt and legacy SHA-256 hashes for migration
|
|
||||||
from app.core.auth import verify_password
|
from app.core.auth import verify_password
|
||||||
|
|
||||||
stored_hash = str(client.client_secret_hash)
|
stored_hash = str(client.client_secret_hash)
|
||||||
|
|
||||||
if stored_hash.startswith("$2"):
|
if not stored_hash.startswith("$2"):
|
||||||
# New bcrypt format
|
raise InvalidClientError(
|
||||||
if not verify_password(client_secret, stored_hash):
|
"Client secret uses deprecated hash format. "
|
||||||
raise InvalidClientError("Invalid client secret")
|
"Please regenerate your client credentials."
|
||||||
else:
|
)
|
||||||
# Legacy SHA-256 format
|
|
||||||
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
if not verify_password(client_secret, stored_hash):
|
||||||
if not secrets.compare_digest(computed_hash, stored_hash):
|
raise InvalidClientError("Invalid client secret")
|
||||||
raise InvalidClientError("Invalid client secret")
|
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@@ -311,23 +304,20 @@ async def create_authorization_code(
|
|||||||
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
|
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_code = OAuthAuthorizationCode(
|
await oauth_authorization_code_repo.create_code(
|
||||||
|
db,
|
||||||
code=code,
|
code=code,
|
||||||
client_id=client.client_id,
|
client_id=client.client_id,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
redirect_uri=redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
scope=scope,
|
scope=scope,
|
||||||
|
expires_at=expires_at,
|
||||||
code_challenge=code_challenge,
|
code_challenge=code_challenge,
|
||||||
code_challenge_method=code_challenge_method,
|
code_challenge_method=code_challenge_method,
|
||||||
state=state,
|
state=state,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
expires_at=expires_at,
|
|
||||||
used=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(auth_code)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Created authorization code for user {user.id} and client {client.client_id}"
|
f"Created authorization code for user {user.id} and client {client.client_id}"
|
||||||
)
|
)
|
||||||
@@ -366,30 +356,14 @@ async def exchange_authorization_code(
|
|||||||
"""
|
"""
|
||||||
# Atomically mark code as used and fetch it (prevents race condition)
|
# Atomically mark code as used and fetch it (prevents race condition)
|
||||||
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
||||||
from sqlalchemy import update
|
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
|
||||||
|
db, code=code
|
||||||
# First, atomically mark the code as used and get affected count
|
|
||||||
update_stmt = (
|
|
||||||
update(OAuthAuthorizationCode)
|
|
||||||
.where(
|
|
||||||
and_(
|
|
||||||
OAuthAuthorizationCode.code == code,
|
|
||||||
OAuthAuthorizationCode.used == False, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.values(used=True)
|
|
||||||
.returning(OAuthAuthorizationCode.id)
|
|
||||||
)
|
)
|
||||||
result = await db.execute(update_stmt)
|
|
||||||
updated_id = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not updated_id:
|
if not updated_id:
|
||||||
# Either code doesn't exist or was already used
|
# Either code doesn't exist or was already used
|
||||||
# Check if it exists to provide appropriate error
|
# Check if it exists to provide appropriate error
|
||||||
check_result = await db.execute(
|
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
|
||||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
|
||||||
)
|
|
||||||
existing_code = check_result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if existing_code and existing_code.used:
|
if existing_code and existing_code.used:
|
||||||
# Code reuse is a security incident - revoke all tokens for this grant
|
# Code reuse is a security incident - revoke all tokens for this grant
|
||||||
@@ -404,11 +378,9 @@ async def exchange_authorization_code(
|
|||||||
raise InvalidGrantError("Invalid authorization code")
|
raise InvalidGrantError("Invalid authorization code")
|
||||||
|
|
||||||
# Now fetch the full auth code record
|
# Now fetch the full auth code record
|
||||||
auth_code_result = await db.execute(
|
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
|
||||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
|
if auth_code is None:
|
||||||
)
|
raise InvalidGrantError("Authorization code not found after consumption")
|
||||||
auth_code = auth_code_result.scalar_one()
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
if auth_code.is_expired:
|
if auth_code.is_expired:
|
||||||
raise InvalidGrantError("Authorization code has expired")
|
raise InvalidGrantError("Authorization code has expired")
|
||||||
@@ -452,8 +424,7 @@ async def exchange_authorization_code(
|
|||||||
raise InvalidGrantError("PKCE required for public clients")
|
raise InvalidGrantError("PKCE required for public clients")
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
|
user = await user_repo.get(db, id=str(auth_code.user_id))
|
||||||
user = user_result.scalar_one_or_none()
|
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
raise InvalidGrantError("User not found or inactive")
|
raise InvalidGrantError("User not found or inactive")
|
||||||
|
|
||||||
@@ -543,7 +514,8 @@ async def create_tokens(
|
|||||||
refresh_token_hash = hash_token(refresh_token)
|
refresh_token_hash = hash_token(refresh_token)
|
||||||
|
|
||||||
# Store refresh token in database
|
# Store refresh token in database
|
||||||
refresh_token_record = OAuthProviderRefreshToken(
|
await oauth_provider_token_repo.create_token(
|
||||||
|
db,
|
||||||
token_hash=refresh_token_hash,
|
token_hash=refresh_token_hash,
|
||||||
jti=jti,
|
jti=jti,
|
||||||
client_id=client.client_id,
|
client_id=client.client_id,
|
||||||
@@ -553,8 +525,6 @@ async def create_tokens(
|
|||||||
device_info=device_info,
|
device_info=device_info,
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
db.add(refresh_token_record)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
logger.info(f"Issued tokens for user {user.id} to client {client.client_id}")
|
logger.info(f"Issued tokens for user {user.id} to client {client.client_id}")
|
||||||
|
|
||||||
@@ -599,12 +569,9 @@ async def refresh_tokens(
|
|||||||
"""
|
"""
|
||||||
# Find refresh token
|
# Find refresh token
|
||||||
token_hash = hash_token(refresh_token)
|
token_hash = hash_token(refresh_token)
|
||||||
result = await db.execute(
|
token_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||||
select(OAuthProviderRefreshToken).where(
|
db, token_hash=token_hash
|
||||||
OAuthProviderRefreshToken.token_hash == token_hash
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not token_record:
|
if not token_record:
|
||||||
raise InvalidGrantError("Invalid refresh token")
|
raise InvalidGrantError("Invalid refresh token")
|
||||||
@@ -631,8 +598,7 @@ async def refresh_tokens(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
|
user = await user_repo.get(db, id=str(token_record.user_id))
|
||||||
user = user_result.scalar_one_or_none()
|
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
raise InvalidGrantError("User not found or inactive")
|
raise InvalidGrantError("User not found or inactive")
|
||||||
|
|
||||||
@@ -648,9 +614,7 @@ async def refresh_tokens(
|
|||||||
final_scope = token_scope
|
final_scope = token_scope
|
||||||
|
|
||||||
# Revoke old refresh token (token rotation)
|
# Revoke old refresh token (token rotation)
|
||||||
token_record.revoked = True # type: ignore[assignment]
|
await oauth_provider_token_repo.revoke(db, token=token_record)
|
||||||
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
# Issue new tokens
|
# Issue new tokens
|
||||||
device = str(token_record.device_info) if token_record.device_info else None
|
device = str(token_record.device_info) if token_record.device_info else None
|
||||||
@@ -697,20 +661,16 @@ async def revoke_token(
|
|||||||
# Try as refresh token first (more likely)
|
# Try as refresh token first (more likely)
|
||||||
if token_type_hint != "access_token":
|
if token_type_hint != "access_token":
|
||||||
token_hash = hash_token(token)
|
token_hash = hash_token(token)
|
||||||
result = await db.execute(
|
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||||
select(OAuthProviderRefreshToken).where(
|
db, token_hash=token_hash
|
||||||
OAuthProviderRefreshToken.token_hash == token_hash
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
refresh_record = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if refresh_record:
|
if refresh_record:
|
||||||
# Validate client if provided
|
# Validate client if provided
|
||||||
if client_id and refresh_record.client_id != client_id:
|
if client_id and refresh_record.client_id != client_id:
|
||||||
raise InvalidClientError("Token was not issued to this client")
|
raise InvalidClientError("Token was not issued to this client")
|
||||||
|
|
||||||
refresh_record.revoked = True # type: ignore[assignment]
|
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||||
await db.commit()
|
|
||||||
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
|
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -731,17 +691,13 @@ async def revoke_token(
|
|||||||
jti = payload.get("jti")
|
jti = payload.get("jti")
|
||||||
if jti:
|
if jti:
|
||||||
# Find and revoke the associated refresh token
|
# Find and revoke the associated refresh token
|
||||||
result = await db.execute(
|
refresh_record = await oauth_provider_token_repo.get_by_jti(
|
||||||
select(OAuthProviderRefreshToken).where(
|
db, jti=jti
|
||||||
OAuthProviderRefreshToken.jti == jti
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
refresh_record = result.scalar_one_or_none()
|
|
||||||
if refresh_record:
|
if refresh_record:
|
||||||
if client_id and refresh_record.client_id != client_id:
|
if client_id and refresh_record.client_id != client_id:
|
||||||
raise InvalidClientError("Token was not issued to this client")
|
raise InvalidClientError("Token was not issued to this client")
|
||||||
refresh_record.revoked = True # type: ignore[assignment]
|
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||||
await db.commit()
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Revoked refresh token via access token JTI {jti[:8]}..."
|
f"Revoked refresh token via access token JTI {jti[:8]}..."
|
||||||
)
|
)
|
||||||
@@ -770,24 +726,11 @@ async def revoke_tokens_for_user_client(
|
|||||||
Returns:
|
Returns:
|
||||||
Number of tokens revoked
|
Number of tokens revoked
|
||||||
"""
|
"""
|
||||||
result = await db.execute(
|
count = await oauth_provider_token_repo.revoke_all_for_user_client(
|
||||||
select(OAuthProviderRefreshToken).where(
|
db, user_id=user_id, client_id=client_id
|
||||||
and_(
|
|
||||||
OAuthProviderRefreshToken.user_id == user_id,
|
|
||||||
OAuthProviderRefreshToken.client_id == client_id,
|
|
||||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
tokens = result.scalars().all()
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for token in tokens:
|
|
||||||
token.revoked = True # type: ignore[assignment]
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
await db.commit()
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Revoked {count} tokens for user {user_id} and client {client_id}"
|
f"Revoked {count} tokens for user {user_id} and client {client_id}"
|
||||||
)
|
)
|
||||||
@@ -808,23 +751,9 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
|
|||||||
Returns:
|
Returns:
|
||||||
Number of tokens revoked
|
Number of tokens revoked
|
||||||
"""
|
"""
|
||||||
result = await db.execute(
|
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
|
||||||
select(OAuthProviderRefreshToken).where(
|
|
||||||
and_(
|
|
||||||
OAuthProviderRefreshToken.user_id == user_id,
|
|
||||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tokens = result.scalars().all()
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for token in tokens:
|
|
||||||
token.revoked = True # type: ignore[assignment]
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
await db.commit()
|
|
||||||
logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}")
|
logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}")
|
||||||
|
|
||||||
return count
|
return count
|
||||||
@@ -878,12 +807,9 @@ async def introspect_token(
|
|||||||
# Check if associated refresh token is revoked
|
# Check if associated refresh token is revoked
|
||||||
jti = payload.get("jti")
|
jti = payload.get("jti")
|
||||||
if jti:
|
if jti:
|
||||||
result = await db.execute(
|
refresh_record = await oauth_provider_token_repo.get_by_jti(
|
||||||
select(OAuthProviderRefreshToken).where(
|
db, jti=jti
|
||||||
OAuthProviderRefreshToken.jti == jti
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
refresh_record = result.scalar_one_or_none()
|
|
||||||
if refresh_record and refresh_record.revoked:
|
if refresh_record and refresh_record.revoked:
|
||||||
return {"active": False}
|
return {"active": False}
|
||||||
|
|
||||||
@@ -907,12 +833,9 @@ async def introspect_token(
|
|||||||
# Try as refresh token
|
# Try as refresh token
|
||||||
if token_type_hint != "access_token":
|
if token_type_hint != "access_token":
|
||||||
token_hash = hash_token(token)
|
token_hash = hash_token(token)
|
||||||
result = await db.execute(
|
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||||
select(OAuthProviderRefreshToken).where(
|
db, token_hash=token_hash
|
||||||
OAuthProviderRefreshToken.token_hash == token_hash
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
refresh_record = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if refresh_record and refresh_record.is_valid:
|
if refresh_record and refresh_record.is_valid:
|
||||||
return {
|
return {
|
||||||
@@ -937,17 +860,9 @@ async def get_consent(
|
|||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
client_id: str,
|
client_id: str,
|
||||||
) -> OAuthConsent | None:
|
):
|
||||||
"""Get existing consent record for user-client pair."""
|
"""Get existing consent record for user-client pair."""
|
||||||
result = await db.execute(
|
return await oauth_consent_repo.get_consent(db, user_id=user_id, client_id=client_id)
|
||||||
select(OAuthConsent).where(
|
|
||||||
and_(
|
|
||||||
OAuthConsent.user_id == user_id,
|
|
||||||
OAuthConsent.client_id == client_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
|
|
||||||
async def check_consent(
|
async def check_consent(
|
||||||
@@ -972,31 +887,15 @@ async def grant_consent(
|
|||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
client_id: str,
|
client_id: str,
|
||||||
scopes: list[str],
|
scopes: list[str],
|
||||||
) -> OAuthConsent:
|
):
|
||||||
"""
|
"""
|
||||||
Grant or update consent for a user-client pair.
|
Grant or update consent for a user-client pair.
|
||||||
|
|
||||||
If consent already exists, updates the granted scopes.
|
If consent already exists, updates the granted scopes.
|
||||||
"""
|
"""
|
||||||
consent = await get_consent(db, user_id, client_id)
|
return await oauth_consent_repo.grant_consent(
|
||||||
|
db, user_id=user_id, client_id=client_id, scopes=scopes
|
||||||
if consent:
|
)
|
||||||
# Merge scopes
|
|
||||||
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
|
|
||||||
existing = set(parse_scope(granted))
|
|
||||||
new_scopes = existing | set(scopes)
|
|
||||||
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
|
|
||||||
else:
|
|
||||||
consent = OAuthConsent(
|
|
||||||
user_id=user_id,
|
|
||||||
client_id=client_id,
|
|
||||||
granted_scopes=join_scope(scopes),
|
|
||||||
)
|
|
||||||
db.add(consent)
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(consent)
|
|
||||||
return consent
|
|
||||||
|
|
||||||
|
|
||||||
async def revoke_consent(
|
async def revoke_consent(
|
||||||
@@ -1009,21 +908,13 @@ async def revoke_consent(
|
|||||||
|
|
||||||
Returns True if consent was found and revoked.
|
Returns True if consent was found and revoked.
|
||||||
"""
|
"""
|
||||||
# Delete consent record
|
# Revoke all tokens first
|
||||||
result = await db.execute(
|
|
||||||
delete(OAuthConsent).where(
|
|
||||||
and_(
|
|
||||||
OAuthConsent.user_id == user_id,
|
|
||||||
OAuthConsent.client_id == client_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Revoke all tokens
|
|
||||||
await revoke_tokens_for_user_client(db, user_id, client_id)
|
await revoke_tokens_for_user_client(db, user_id, client_id)
|
||||||
|
|
||||||
await db.commit()
|
# Delete consent record
|
||||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
return await oauth_consent_repo.revoke_consent(
|
||||||
|
db, user_id=user_id, client_id=client_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -1031,6 +922,26 @@ async def revoke_consent(
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
|
||||||
|
"""Create a new OAuth client. Returns (client, secret)."""
|
||||||
|
return await oauth_client_repo.create_client(db, obj_in=client_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_clients(db: AsyncSession) -> list:
|
||||||
|
"""List all registered OAuth clients."""
|
||||||
|
return await oauth_client_repo.get_all_clients(db)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
|
||||||
|
"""Delete an OAuth client by client_id."""
|
||||||
|
await oauth_client_repo.delete_client(db, client_id=client_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
|
||||||
|
"""Get all OAuth consents for a user with client details."""
|
||||||
|
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_expired_codes(db: AsyncSession) -> int:
|
async def cleanup_expired_codes(db: AsyncSession) -> int:
|
||||||
"""
|
"""
|
||||||
Delete expired authorization codes.
|
Delete expired authorization codes.
|
||||||
@@ -1040,13 +951,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
|
|||||||
Returns:
|
Returns:
|
||||||
Number of codes deleted
|
Number of codes deleted
|
||||||
"""
|
"""
|
||||||
result = await db.execute(
|
return await oauth_authorization_code_repo.cleanup_expired(db)
|
||||||
delete(OAuthAuthorizationCode).where(
|
|
||||||
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
return result.rowcount # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||||
@@ -1058,12 +963,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
|||||||
Returns:
|
Returns:
|
||||||
Number of tokens deleted
|
Number of tokens deleted
|
||||||
"""
|
"""
|
||||||
# Delete tokens that are both expired AND revoked (or just very old)
|
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)
|
||||||
cutoff = datetime.now(UTC) - timedelta(days=7)
|
|
||||||
result = await db.execute(
|
|
||||||
delete(OAuthProviderRefreshToken).where(
|
|
||||||
OAuthProviderRefreshToken.expires_at < cutoff
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
return result.rowcount # type: ignore[attr-defined]
|
|
||||||
|
|||||||
@@ -19,14 +19,15 @@ from typing import TypedDict, cast
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import create_access_token, create_refresh_token
|
from app.core.auth import create_access_token, create_refresh_token
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.exceptions import AuthenticationError
|
from app.core.exceptions import AuthenticationError
|
||||||
from app.crud import oauth_account, oauth_state
|
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||||
|
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.user import user_repo
|
||||||
from app.schemas.oauth import (
|
from app.schemas.oauth import (
|
||||||
OAuthAccountCreate,
|
OAuthAccountCreate,
|
||||||
OAuthCallbackResponse,
|
OAuthCallbackResponse,
|
||||||
@@ -343,7 +344,7 @@ class OAuthService:
|
|||||||
await oauth_account.update_tokens(
|
await oauth_account.update_tokens(
|
||||||
db,
|
db,
|
||||||
account=existing_oauth,
|
account=existing_oauth,
|
||||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||||
+ timedelta(seconds=token.get("expires_in", 3600)),
|
+ timedelta(seconds=token.get("expires_in", 3600)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -351,10 +352,7 @@ class OAuthService:
|
|||||||
|
|
||||||
elif state_record.user_id:
|
elif state_record.user_id:
|
||||||
# Account linking flow (user is already logged in)
|
# Account linking flow (user is already logged in)
|
||||||
result = await db.execute(
|
user = await user_repo.get(db, id=str(state_record.user_id))
|
||||||
select(User).where(User.id == state_record.user_id)
|
|
||||||
)
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise AuthenticationError("User not found for account linking")
|
raise AuthenticationError("User not found for account linking")
|
||||||
@@ -375,7 +373,7 @@ class OAuthService:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
provider_user_id=provider_user_id,
|
provider_user_id=provider_user_id,
|
||||||
provider_email=provider_email,
|
provider_email=provider_email,
|
||||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||||
if token.get("expires_in")
|
if token.get("expires_in")
|
||||||
else None,
|
else None,
|
||||||
@@ -389,10 +387,7 @@ class OAuthService:
|
|||||||
user = None
|
user = None
|
||||||
|
|
||||||
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
|
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
|
||||||
result = await db.execute(
|
user = await user_repo.get_by_email(db, email=provider_email)
|
||||||
select(User).where(User.email == provider_email)
|
|
||||||
)
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
# Auto-link to existing user
|
# Auto-link to existing user
|
||||||
@@ -416,8 +411,8 @@ class OAuthService:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
provider_user_id=provider_user_id,
|
provider_user_id=provider_user_id,
|
||||||
provider_email=provider_email,
|
provider_email=provider_email,
|
||||||
access_token_encrypted=token.get("access_token"),
|
access_token=token.get("access_token"),
|
||||||
refresh_token_encrypted=token.get("refresh_token"),
|
refresh_token=token.get("refresh_token"),
|
||||||
token_expires_at=datetime.now(UTC)
|
token_expires_at=datetime.now(UTC)
|
||||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||||
if token.get("expires_in")
|
if token.get("expires_in")
|
||||||
@@ -644,14 +639,13 @@ class OAuthService:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
provider_user_id=provider_user_id,
|
provider_user_id=provider_user_id,
|
||||||
provider_email=email,
|
provider_email=email,
|
||||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||||
if token.get("expires_in")
|
if token.get("expires_in")
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
@@ -701,6 +695,20 @@ class OAuthService:
|
|||||||
logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
|
logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
|
||||||
|
"""Get all OAuth accounts linked to a user."""
|
||||||
|
return await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_account_by_provider(
|
||||||
|
db: AsyncSession, *, user_id: UUID, provider: str
|
||||||
|
):
|
||||||
|
"""Get a specific OAuth account for a user and provider."""
|
||||||
|
return await oauth_account.get_user_account_by_provider(
|
||||||
|
db, user_id=user_id, provider=provider
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def cleanup_expired_states(db: AsyncSession) -> int:
|
async def cleanup_expired_states(db: AsyncSession) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
157
backend/app/services/organization_service.py
Normal file
157
backend/app/services/organization_service.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
# app/services/organization_service.py
|
||||||
|
"""Service layer for organization operations — delegates to OrganizationRepository."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.exceptions import NotFoundError
|
||||||
|
from app.models.organization import Organization
|
||||||
|
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||||
|
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||||
|
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationService:
|
||||||
|
"""Service for organization management operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, organization_repository: OrganizationRepository | None = None
|
||||||
|
) -> None:
|
||||||
|
self._repo = organization_repository or organization_repo
|
||||||
|
|
||||||
|
async def get_organization(self, db: AsyncSession, org_id: str) -> Organization:
|
||||||
|
"""Get organization by ID, raising NotFoundError if not found."""
|
||||||
|
org = await self._repo.get(db, id=org_id)
|
||||||
|
if not org:
|
||||||
|
raise NotFoundError(f"Organization {org_id} not found")
|
||||||
|
return org
|
||||||
|
|
||||||
|
async def create_organization(
|
||||||
|
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||||
|
) -> Organization:
|
||||||
|
"""Create a new organization."""
|
||||||
|
return await self._repo.create(db, obj_in=obj_in)
|
||||||
|
|
||||||
|
async def update_organization(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
org: Organization,
|
||||||
|
obj_in: OrganizationUpdate | dict[str, Any],
|
||||||
|
) -> Organization:
|
||||||
|
"""Update an existing organization."""
|
||||||
|
return await self._repo.update(db, db_obj=org, obj_in=obj_in)
|
||||||
|
|
||||||
|
async def remove_organization(self, db: AsyncSession, org_id: str) -> None:
|
||||||
|
"""Permanently delete an organization by ID."""
|
||||||
|
await self._repo.remove(db, id=org_id)
|
||||||
|
|
||||||
|
async def get_member_count(
|
||||||
|
self, db: AsyncSession, *, organization_id: UUID
|
||||||
|
) -> int:
|
||||||
|
"""Get number of active members in an organization."""
|
||||||
|
return await self._repo.get_member_count(db, organization_id=organization_id)
|
||||||
|
|
||||||
|
async def get_multi_with_member_counts(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
is_active: bool | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
|
"""List organizations with member counts and pagination."""
|
||||||
|
return await self._repo.get_multi_with_member_counts(
|
||||||
|
db, skip=skip, limit=limit, is_active=is_active, search=search
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user_organizations_with_details(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: UUID,
|
||||||
|
is_active: bool | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Get all organizations a user belongs to, with membership details."""
|
||||||
|
return await self._repo.get_user_organizations_with_details(
|
||||||
|
db, user_id=user_id, is_active=is_active
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_organization_members(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
organization_id: UUID,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
is_active: bool | None = True,
|
||||||
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
|
"""Get members of an organization with pagination."""
|
||||||
|
return await self._repo.get_organization_members(
|
||||||
|
db,
|
||||||
|
organization_id=organization_id,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
is_active=is_active,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def add_member(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
organization_id: UUID,
|
||||||
|
user_id: UUID,
|
||||||
|
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||||
|
) -> UserOrganization:
|
||||||
|
"""Add a user to an organization."""
|
||||||
|
return await self._repo.add_user(
|
||||||
|
db, organization_id=organization_id, user_id=user_id, role=role
|
||||||
|
)
|
||||||
|
|
||||||
|
async def remove_member(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
organization_id: UUID,
|
||||||
|
user_id: UUID,
|
||||||
|
) -> bool:
|
||||||
|
"""Remove a user from an organization. Returns True if found and removed."""
|
||||||
|
return await self._repo.remove_user(
|
||||||
|
db, organization_id=organization_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user_role_in_org(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||||
|
) -> OrganizationRole | None:
|
||||||
|
"""Get the role of a user in an organization."""
|
||||||
|
return await self._repo.get_user_role_in_org(
|
||||||
|
db, user_id=user_id, organization_id=organization_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_org_distribution(
|
||||||
|
self, db: AsyncSession, *, limit: int = 6
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return top organizations by member count for admin dashboard."""
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(
|
||||||
|
Organization.name,
|
||||||
|
func.count(UserOrganization.user_id).label("count"),
|
||||||
|
)
|
||||||
|
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||||
|
.group_by(Organization.name)
|
||||||
|
.order_by(func.count(UserOrganization.user_id).desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return [{"name": row.name, "value": row.count} for row in result.all()]
|
||||||
|
|
||||||
|
|
||||||
|
# Default singleton
|
||||||
|
organization_service = OrganizationService()
|
||||||
@@ -8,7 +8,7 @@ import logging
|
|||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from app.core.database import SessionLocal
|
from app.core.database import SessionLocal
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
97
backend/app/services/session_service.py
Normal file
97
backend/app/services/session_service.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
# app/services/session_service.py
|
||||||
|
"""Service layer for session operations — delegates to SessionRepository."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.user_session import UserSession
|
||||||
|
from app.repositories.session import SessionRepository, session_repo
|
||||||
|
from app.schemas.sessions import SessionCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionService:
|
||||||
|
"""Service for user session management operations."""
|
||||||
|
|
||||||
|
def __init__(self, session_repository: SessionRepository | None = None) -> None:
|
||||||
|
self._repo = session_repository or session_repo
|
||||||
|
|
||||||
|
async def create_session(
|
||||||
|
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||||
|
) -> UserSession:
|
||||||
|
"""Create a new session record."""
|
||||||
|
return await self._repo.create_session(db, obj_in=obj_in)
|
||||||
|
|
||||||
|
async def get_session(self, db: AsyncSession, session_id: str) -> UserSession | None:
|
||||||
|
"""Get session by ID."""
|
||||||
|
return await self._repo.get(db, id=session_id)
|
||||||
|
|
||||||
|
async def get_user_sessions(
|
||||||
|
self, db: AsyncSession, *, user_id: str, active_only: bool = True
|
||||||
|
) -> list[UserSession]:
|
||||||
|
"""Get all sessions for a user."""
|
||||||
|
return await self._repo.get_user_sessions(
|
||||||
|
db, user_id=user_id, active_only=active_only
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_active_by_jti(
|
||||||
|
self, db: AsyncSession, *, jti: str
|
||||||
|
) -> UserSession | None:
|
||||||
|
"""Get active session by refresh token JTI."""
|
||||||
|
return await self._repo.get_active_by_jti(db, jti=jti)
|
||||||
|
|
||||||
|
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||||
|
"""Get session by refresh token JTI (active or inactive)."""
|
||||||
|
return await self._repo.get_by_jti(db, jti=jti)
|
||||||
|
|
||||||
|
async def deactivate(
|
||||||
|
self, db: AsyncSession, *, session_id: str
|
||||||
|
) -> UserSession | None:
|
||||||
|
"""Deactivate a session (logout from device)."""
|
||||||
|
return await self._repo.deactivate(db, session_id=session_id)
|
||||||
|
|
||||||
|
async def deactivate_all_user_sessions(
|
||||||
|
self, db: AsyncSession, *, user_id: str
|
||||||
|
) -> int:
|
||||||
|
"""Deactivate all sessions for a user. Returns count deactivated."""
|
||||||
|
return await self._repo.deactivate_all_user_sessions(db, user_id=user_id)
|
||||||
|
|
||||||
|
async def update_refresh_token(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
session: UserSession,
|
||||||
|
new_jti: str,
|
||||||
|
new_expires_at: datetime,
|
||||||
|
) -> UserSession:
|
||||||
|
"""Update session with a rotated refresh token."""
|
||||||
|
return await self._repo.update_refresh_token(
|
||||||
|
db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def cleanup_expired_for_user(
|
||||||
|
self, db: AsyncSession, *, user_id: str
|
||||||
|
) -> int:
|
||||||
|
"""Remove expired sessions for a user. Returns count removed."""
|
||||||
|
return await self._repo.cleanup_expired_for_user(db, user_id=user_id)
|
||||||
|
|
||||||
|
async def get_all_sessions(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
active_only: bool = True,
|
||||||
|
with_user: bool = True,
|
||||||
|
) -> tuple[list[UserSession], int]:
|
||||||
|
"""Get all sessions with pagination (admin only)."""
|
||||||
|
return await self._repo.get_all_sessions(
|
||||||
|
db, skip=skip, limit=limit, active_only=active_only, with_user=with_user
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Default singleton
|
||||||
|
session_service = SessionService()
|
||||||
120
backend/app/services/user_service.py
Normal file
120
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
# app/services/user_service.py
|
||||||
|
"""Service layer for user operations — delegates to UserRepository."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.exceptions import NotFoundError
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.user import UserRepository, user_repo
|
||||||
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UserService:
|
||||||
|
"""Service for user management operations."""
|
||||||
|
|
||||||
|
def __init__(self, user_repository: UserRepository | None = None) -> None:
|
||||||
|
self._repo = user_repository or user_repo
|
||||||
|
|
||||||
|
async def get_user(self, db: AsyncSession, user_id: str) -> User:
|
||||||
|
"""Get user by ID, raising NotFoundError if not found."""
|
||||||
|
user = await self._repo.get(db, id=user_id)
|
||||||
|
if not user:
|
||||||
|
raise NotFoundError(f"User {user_id} not found")
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
|
||||||
|
"""Get user by email address."""
|
||||||
|
return await self._repo.get_by_email(db, email=email)
|
||||||
|
|
||||||
|
async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User:
|
||||||
|
"""Create a new user."""
|
||||||
|
return await self._repo.create(db, obj_in=user_data)
|
||||||
|
|
||||||
|
async def update_user(
|
||||||
|
self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any]
|
||||||
|
) -> User:
|
||||||
|
"""Update an existing user."""
|
||||||
|
return await self._repo.update(db, db_obj=user, obj_in=obj_in)
|
||||||
|
|
||||||
|
async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Soft-delete a user by ID."""
|
||||||
|
await self._repo.soft_delete(db, id=user_id)
|
||||||
|
|
||||||
|
async def list_users(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str | None = None,
|
||||||
|
sort_order: str = "asc",
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
) -> tuple[list[User], int]:
|
||||||
|
"""List users with pagination, sorting, filtering, and search."""
|
||||||
|
return await self._repo.get_multi_with_total(
|
||||||
|
db,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order,
|
||||||
|
filters=filters,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def bulk_update_status(
|
||||||
|
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||||
|
) -> int:
|
||||||
|
"""Bulk update active status for multiple users. Returns count updated."""
|
||||||
|
return await self._repo.bulk_update_status(
|
||||||
|
db, user_ids=user_ids, is_active=is_active
|
||||||
|
)
|
||||||
|
|
||||||
|
async def bulk_soft_delete(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_ids: list[UUID],
|
||||||
|
exclude_user_id: UUID | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Bulk soft-delete multiple users. Returns count deleted."""
|
||||||
|
return await self._repo.bulk_soft_delete(
|
||||||
|
db, user_ids=user_ids, exclude_user_id=exclude_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_stats(self, db: AsyncSession) -> dict[str, Any]:
|
||||||
|
"""Return user stats needed for the admin dashboard."""
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
|
total_users = (
|
||||||
|
await db.execute(select(func.count()).select_from(User))
|
||||||
|
).scalar() or 0
|
||||||
|
active_count = (
|
||||||
|
await db.execute(select(func.count()).select_from(User).where(User.is_active))
|
||||||
|
).scalar() or 0
|
||||||
|
inactive_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count()).select_from(User).where(User.is_active.is_(False))
|
||||||
|
)
|
||||||
|
).scalar() or 0
|
||||||
|
all_users = list(
|
||||||
|
(
|
||||||
|
await db.execute(select(User).order_by(User.created_at))
|
||||||
|
).scalars().all()
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"total_users": total_users,
|
||||||
|
"active_count": active_count,
|
||||||
|
"inactive_count": inactive_count,
|
||||||
|
"all_users": all_users,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Default singleton
|
||||||
|
user_service = UserService()
|
||||||
@@ -147,7 +147,7 @@ class TestAdminCreateUser:
|
|||||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
|
||||||
|
|
||||||
class TestAdminGetUser:
|
class TestAdminGetUser:
|
||||||
@@ -565,7 +565,7 @@ class TestAdminCreateOrganization:
|
|||||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
|
||||||
|
|
||||||
class TestAdminGetOrganization:
|
class TestAdminGetOrganization:
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class TestAdminListUsersFilters:
|
|||||||
async def test_list_users_database_error_propagates(self, client, superuser_token):
|
async def test_list_users_database_error_propagates(self, client, superuser_token):
|
||||||
"""Test that database errors propagate correctly (covers line 118-120)."""
|
"""Test that database errors propagate correctly (covers line 118-120)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.user_crud.get_multi_with_total",
|
"app.api.routes.admin.user_service.list_users",
|
||||||
side_effect=Exception("DB error"),
|
side_effect=Exception("DB error"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -74,8 +74,8 @@ class TestAdminCreateUserErrors:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should get error for duplicate email
|
# Should get conflict for duplicate email
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_user_unexpected_error_propagates(
|
async def test_create_user_unexpected_error_propagates(
|
||||||
@@ -83,7 +83,7 @@ class TestAdminCreateUserErrors:
|
|||||||
):
|
):
|
||||||
"""Test unexpected errors during user creation (covers line 151-153)."""
|
"""Test unexpected errors during user creation (covers line 151-153)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.user_crud.create",
|
"app.api.routes.admin.user_service.create_user",
|
||||||
side_effect=RuntimeError("Unexpected error"),
|
side_effect=RuntimeError("Unexpected error"),
|
||||||
):
|
):
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
@@ -135,7 +135,7 @@ class TestAdminUpdateUserErrors:
|
|||||||
):
|
):
|
||||||
"""Test unexpected errors during user update (covers line 206-208)."""
|
"""Test unexpected errors during user update (covers line 206-208)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.user_crud.update",
|
"app.api.routes.admin.user_service.update_user",
|
||||||
side_effect=RuntimeError("Update failed"),
|
side_effect=RuntimeError("Update failed"),
|
||||||
):
|
):
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
@@ -166,7 +166,7 @@ class TestAdminDeleteUserErrors:
|
|||||||
):
|
):
|
||||||
"""Test unexpected errors during user deletion (covers line 238-240)."""
|
"""Test unexpected errors during user deletion (covers line 238-240)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.user_crud.soft_delete",
|
"app.api.routes.admin.user_service.soft_delete_user",
|
||||||
side_effect=Exception("Delete failed"),
|
side_effect=Exception("Delete failed"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -196,7 +196,7 @@ class TestAdminActivateUserErrors:
|
|||||||
):
|
):
|
||||||
"""Test unexpected errors during user activation (covers line 282-284)."""
|
"""Test unexpected errors during user activation (covers line 282-284)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.user_crud.update",
|
"app.api.routes.admin.user_service.update_user",
|
||||||
side_effect=Exception("Activation failed"),
|
side_effect=Exception("Activation failed"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -238,7 +238,7 @@ class TestAdminDeactivateUserErrors:
|
|||||||
):
|
):
|
||||||
"""Test unexpected errors during user deactivation (covers line 326-328)."""
|
"""Test unexpected errors during user deactivation (covers line 326-328)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.user_crud.update",
|
"app.api.routes.admin.user_service.update_user",
|
||||||
side_effect=Exception("Deactivation failed"),
|
side_effect=Exception("Deactivation failed"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -258,7 +258,7 @@ class TestAdminListOrganizationsErrors:
|
|||||||
async def test_list_organizations_database_error(self, client, superuser_token):
|
async def test_list_organizations_database_error(self, client, superuser_token):
|
||||||
"""Test list organizations with database error (covers line 427-456)."""
|
"""Test list organizations with database error (covers line 427-456)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.organization_crud.get_multi_with_member_counts",
|
"app.api.routes.admin.organization_service.get_multi_with_member_counts",
|
||||||
side_effect=Exception("DB error"),
|
side_effect=Exception("DB error"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -299,14 +299,14 @@ class TestAdminCreateOrganizationErrors:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should get error for duplicate slug
|
# Should get conflict for duplicate slug
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_organization_unexpected_error(self, client, superuser_token):
|
async def test_create_organization_unexpected_error(self, client, superuser_token):
|
||||||
"""Test unexpected errors during organization creation (covers line 484-485)."""
|
"""Test unexpected errors during organization creation (covers line 484-485)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.organization_crud.create",
|
"app.api.routes.admin.organization_service.create_organization",
|
||||||
side_effect=RuntimeError("Creation failed"),
|
side_effect=RuntimeError("Creation failed"),
|
||||||
):
|
):
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
@@ -367,7 +367,7 @@ class TestAdminUpdateOrganizationErrors:
|
|||||||
org_id = org.id
|
org_id = org.id
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.organization_crud.update",
|
"app.api.routes.admin.organization_service.update_organization",
|
||||||
side_effect=Exception("Update failed"),
|
side_effect=Exception("Update failed"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -412,7 +412,7 @@ class TestAdminDeleteOrganizationErrors:
|
|||||||
org_id = org.id
|
org_id = org.id
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.organization_crud.remove",
|
"app.api.routes.admin.organization_service.remove_organization",
|
||||||
side_effect=Exception("Delete failed"),
|
side_effect=Exception("Delete failed"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -456,7 +456,7 @@ class TestAdminListOrganizationMembersErrors:
|
|||||||
org_id = org.id
|
org_id = org.id
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.organization_crud.get_organization_members",
|
"app.api.routes.admin.organization_service.get_organization_members",
|
||||||
side_effect=Exception("DB error"),
|
side_effect=Exception("DB error"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -531,7 +531,7 @@ class TestAdminAddOrganizationMemberErrors:
|
|||||||
org_id = org.id
|
org_id = org.id
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.organization_crud.add_user",
|
"app.api.routes.admin.organization_service.add_member",
|
||||||
side_effect=Exception("Add failed"),
|
side_effect=Exception("Add failed"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
@@ -587,7 +587,7 @@ class TestAdminRemoveOrganizationMemberErrors:
|
|||||||
org_id = org.id
|
org_id = org.id
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.admin.organization_crud.remove_user",
|
"app.api.routes.admin.organization_service.remove_member",
|
||||||
side_effect=Exception("Remove failed"),
|
side_effect=Exception("Remove failed"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class TestLoginSessionCreationFailure:
|
|||||||
"""Test that login succeeds even if session creation fails."""
|
"""Test that login succeeds even if session creation fails."""
|
||||||
# Mock session creation to fail
|
# Mock session creation to fail
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.auth.session_crud.create_session",
|
"app.api.routes.auth.session_service.create_session",
|
||||||
side_effect=Exception("Session creation failed"),
|
side_effect=Exception("Session creation failed"),
|
||||||
):
|
):
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@@ -43,7 +43,7 @@ class TestOAuthLoginSessionCreationFailure:
|
|||||||
):
|
):
|
||||||
"""Test OAuth login succeeds even if session creation fails."""
|
"""Test OAuth login succeeds even if session creation fails."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.auth.session_crud.create_session",
|
"app.api.routes.auth.session_service.create_session",
|
||||||
side_effect=Exception("Session failed"),
|
side_effect=Exception("Session failed"),
|
||||||
):
|
):
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@@ -76,7 +76,7 @@ class TestRefreshTokenSessionUpdateFailure:
|
|||||||
|
|
||||||
# Mock session update to fail
|
# Mock session update to fail
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.auth.session_crud.update_refresh_token",
|
"app.api.routes.auth.session_service.update_refresh_token",
|
||||||
side_effect=Exception("Update failed"),
|
side_effect=Exception("Update failed"),
|
||||||
):
|
):
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@@ -130,7 +130,7 @@ class TestLogoutWithNonExistentSession:
|
|||||||
tokens = response.json()
|
tokens = response.json()
|
||||||
|
|
||||||
# Mock session lookup to return None
|
# Mock session lookup to return None
|
||||||
with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None):
|
with patch("app.api.routes.auth.session_service.get_by_jti", return_value=None):
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/v1/auth/logout",
|
"/api/v1/auth/logout",
|
||||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||||
@@ -157,7 +157,7 @@ class TestLogoutUnexpectedError:
|
|||||||
|
|
||||||
# Mock to raise unexpected error
|
# Mock to raise unexpected error
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.auth.session_crud.get_by_jti",
|
"app.api.routes.auth.session_service.get_by_jti",
|
||||||
side_effect=Exception("Unexpected error"),
|
side_effect=Exception("Unexpected error"),
|
||||||
):
|
):
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@@ -186,7 +186,7 @@ class TestLogoutAllUnexpectedError:
|
|||||||
|
|
||||||
# Mock to raise database error
|
# Mock to raise database error
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
|
"app.api.routes.auth.session_service.deactivate_all_user_sessions",
|
||||||
side_effect=Exception("DB error"),
|
side_effect=Exception("DB error"),
|
||||||
):
|
):
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@@ -212,7 +212,7 @@ class TestPasswordResetConfirmSessionInvalidation:
|
|||||||
|
|
||||||
# Mock session invalidation to fail
|
# Mock session invalidation to fail
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
|
"app.api.routes.auth.session_service.deactivate_all_user_sessions",
|
||||||
side_effect=Exception("Invalidation failed"),
|
side_effect=Exception("Invalidation failed"),
|
||||||
):
|
):
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
|
|||||||
@@ -334,7 +334,7 @@ class TestPasswordResetConfirm:
|
|||||||
token = create_password_reset_token(async_test_user.email)
|
token = create_password_reset_token(async_test_user.email)
|
||||||
|
|
||||||
# Mock the database commit to raise an exception
|
# Mock the database commit to raise an exception
|
||||||
with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get:
|
with patch("app.services.auth_service.user_repo.get_by_email") as mock_get:
|
||||||
mock_get.side_effect = Exception("Database error")
|
mock_get.side_effect = Exception("Database error")
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ These tests prevent real-world attack scenarios.
|
|||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from uuid import uuid4
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.crud.oauth import oauth_account
|
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||||
from app.schemas.oauth import OAuthAccountCreate
|
from app.schemas.oauth import OAuthAccountCreate
|
||||||
|
|
||||||
|
|
||||||
@@ -349,7 +349,7 @@ class TestOAuthProviderEndpoints:
|
|||||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Create a test client
|
# Create a test client
|
||||||
from app.crud.oauth import oauth_client
|
from app.repositories.oauth_client import oauth_client_repo as oauth_client
|
||||||
from app.schemas.oauth import OAuthClientCreate
|
from app.schemas.oauth import OAuthClientCreate
|
||||||
|
|
||||||
async with AsyncTestingSessionLocal() as session:
|
async with AsyncTestingSessionLocal() as session:
|
||||||
@@ -386,7 +386,7 @@ class TestOAuthProviderEndpoints:
|
|||||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
# Create a test client
|
# Create a test client
|
||||||
from app.crud.oauth import oauth_client
|
from app.repositories.oauth_client import oauth_client_repo as oauth_client
|
||||||
from app.schemas.oauth import OAuthClientCreate
|
from app.schemas.oauth import OAuthClientCreate
|
||||||
|
|
||||||
async with AsyncTestingSessionLocal() as session:
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
|||||||
@@ -537,7 +537,7 @@ class TestOrganizationExceptionHandlers:
|
|||||||
):
|
):
|
||||||
"""Test generic exception handler in get_my_organizations (covers lines 81-83)."""
|
"""Test generic exception handler in get_my_organizations (covers lines 81-83)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.crud.organization.organization.get_user_organizations_with_details",
|
"app.api.routes.organizations.organization_service.get_user_organizations_with_details",
|
||||||
side_effect=Exception("Database connection lost"),
|
side_effect=Exception("Database connection lost"),
|
||||||
):
|
):
|
||||||
# The exception handler logs and re-raises, so we expect the exception
|
# The exception handler logs and re-raises, so we expect the exception
|
||||||
@@ -554,7 +554,7 @@ class TestOrganizationExceptionHandlers:
|
|||||||
):
|
):
|
||||||
"""Test generic exception handler in get_organization (covers lines 124-128)."""
|
"""Test generic exception handler in get_organization (covers lines 124-128)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.crud.organization.organization.get",
|
"app.api.routes.organizations.organization_service.get_organization",
|
||||||
side_effect=Exception("Database timeout"),
|
side_effect=Exception("Database timeout"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception, match="Database timeout"):
|
with pytest.raises(Exception, match="Database timeout"):
|
||||||
@@ -569,7 +569,7 @@ class TestOrganizationExceptionHandlers:
|
|||||||
):
|
):
|
||||||
"""Test generic exception handler in get_organization_members (covers lines 170-172)."""
|
"""Test generic exception handler in get_organization_members (covers lines 170-172)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.crud.organization.organization.get_organization_members",
|
"app.api.routes.organizations.organization_service.get_organization_members",
|
||||||
side_effect=Exception("Connection pool exhausted"),
|
side_effect=Exception("Connection pool exhausted"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception, match="Connection pool exhausted"):
|
with pytest.raises(Exception, match="Connection pool exhausted"):
|
||||||
@@ -591,11 +591,11 @@ class TestOrganizationExceptionHandlers:
|
|||||||
admin_token = login_response.json()["access_token"]
|
admin_token = login_response.json()["access_token"]
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.crud.organization.organization.get",
|
"app.api.routes.organizations.organization_service.get_organization",
|
||||||
return_value=test_org_with_user_admin,
|
return_value=test_org_with_user_admin,
|
||||||
):
|
):
|
||||||
with patch(
|
with patch(
|
||||||
"app.crud.organization.organization.update",
|
"app.api.routes.organizations.organization_service.update_organization",
|
||||||
side_effect=Exception("Write lock timeout"),
|
side_effect=Exception("Write lock timeout"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception, match="Write lock timeout"):
|
with pytest.raises(Exception, match="Write lock timeout"):
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ These tests prevent unauthorized access and privilege escalation.
|
|||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from app.crud.user import user as user_crud
|
from app.repositories.user import user_repo as user_crud
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ async def async_test_user2(async_test_db):
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
from app.crud.user import user as user_crud
|
from app.repositories.user import user_repo as user_crud
|
||||||
from app.schemas.users import UserCreate
|
from app.schemas.users import UserCreate
|
||||||
|
|
||||||
user_data = UserCreate(
|
user_data = UserCreate(
|
||||||
@@ -191,7 +191,7 @@ class TestRevokeSession:
|
|||||||
|
|
||||||
# Verify session is deactivated
|
# Verify session is deactivated
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
|
|
||||||
revoked_session = await session_crud.get(session, id=str(session_id))
|
revoked_session = await session_crud.get(session, id=str(session_id))
|
||||||
assert revoked_session.is_active is False
|
assert revoked_session.is_active is False
|
||||||
@@ -268,7 +268,7 @@ class TestCleanupExpiredSessions:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
# Create expired and active sessions using CRUD to avoid greenlet issues
|
# Create expired and active sessions using CRUD to avoid greenlet issues
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.schemas.sessions import SessionCreate
|
from app.schemas.sessions import SessionCreate
|
||||||
|
|
||||||
async with SessionLocal() as db:
|
async with SessionLocal() as db:
|
||||||
@@ -334,7 +334,7 @@ class TestCleanupExpiredSessions:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
# Create only active sessions using CRUD
|
# Create only active sessions using CRUD
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.schemas.sessions import SessionCreate
|
from app.schemas.sessions import SessionCreate
|
||||||
|
|
||||||
async with SessionLocal() as db:
|
async with SessionLocal() as db:
|
||||||
@@ -384,7 +384,7 @@ class TestSessionsAdditionalCases:
|
|||||||
|
|
||||||
# Create multiple sessions
|
# Create multiple sessions
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.schemas.sessions import SessionCreate
|
from app.schemas.sessions import SessionCreate
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
@@ -431,7 +431,7 @@ class TestSessionsAdditionalCases:
|
|||||||
"""Test cleanup with mix of active/inactive and expired/not-expired sessions."""
|
"""Test cleanup with mix of active/inactive and expired/not-expired sessions."""
|
||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.schemas.sessions import SessionCreate
|
from app.schemas.sessions import SessionCreate
|
||||||
|
|
||||||
async with SessionLocal() as db:
|
async with SessionLocal() as db:
|
||||||
@@ -502,10 +502,10 @@ class TestSessionExceptionHandlers:
|
|||||||
"""Test list_sessions handles database errors (covers lines 104-106)."""
|
"""Test list_sessions handles database errors (covers lines 104-106)."""
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from app.crud import session as session_module
|
from app.repositories import session as session_module
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
session_module.session,
|
session_module.session_repo,
|
||||||
"get_user_sessions",
|
"get_user_sessions",
|
||||||
side_effect=Exception("Database error"),
|
side_effect=Exception("Database error"),
|
||||||
):
|
):
|
||||||
@@ -527,10 +527,10 @@ class TestSessionExceptionHandlers:
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from app.crud import session as session_module
|
from app.repositories import session as session_module
|
||||||
|
|
||||||
# First create a session to revoke
|
# First create a session to revoke
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.schemas.sessions import SessionCreate
|
from app.schemas.sessions import SessionCreate
|
||||||
|
|
||||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
@@ -550,7 +550,7 @@ class TestSessionExceptionHandlers:
|
|||||||
|
|
||||||
# Mock the deactivate method to raise an exception
|
# Mock the deactivate method to raise an exception
|
||||||
with patch.object(
|
with patch.object(
|
||||||
session_module.session,
|
session_module.session_repo,
|
||||||
"deactivate",
|
"deactivate",
|
||||||
side_effect=Exception("Database connection lost"),
|
side_effect=Exception("Database connection lost"),
|
||||||
):
|
):
|
||||||
@@ -568,10 +568,10 @@ class TestSessionExceptionHandlers:
|
|||||||
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
|
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from app.crud import session as session_module
|
from app.repositories import session as session_module
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
session_module.session,
|
session_module.session_repo,
|
||||||
"cleanup_expired_for_user",
|
"cleanup_expired_for_user",
|
||||||
side_effect=Exception("Cleanup failed"),
|
side_effect=Exception("Cleanup failed"),
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ class TestUpdateCurrentUser:
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.users.user_crud.update", side_effect=Exception("DB error")
|
"app.api.routes.users.user_service.update_user", side_effect=Exception("DB error")
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
await client.patch(
|
await client.patch(
|
||||||
@@ -134,7 +134,7 @@ class TestUpdateCurrentUser:
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.users.user_crud.update",
|
"app.api.routes.users.user_service.update_user",
|
||||||
side_effect=ValueError("Invalid value"),
|
side_effect=ValueError("Invalid value"),
|
||||||
):
|
):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@@ -224,7 +224,7 @@ class TestUpdateUserById:
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid")
|
"app.api.routes.users.user_service.update_user", side_effect=ValueError("Invalid")
|
||||||
):
|
):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await client.patch(
|
await client.patch(
|
||||||
@@ -241,7 +241,7 @@ class TestUpdateUserById:
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected")
|
"app.api.routes.users.user_service.update_user", side_effect=Exception("Unexpected")
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
await client.patch(
|
await client.patch(
|
||||||
@@ -354,7 +354,7 @@ class TestDeleteUserById:
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.users.user_crud.soft_delete",
|
"app.api.routes.users.user_service.soft_delete_user",
|
||||||
side_effect=ValueError("Cannot delete"),
|
side_effect=ValueError("Cannot delete"),
|
||||||
):
|
):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@@ -371,7 +371,7 @@ class TestDeleteUserById:
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.users.user_crud.soft_delete",
|
"app.api.routes.users.user_service.soft_delete_user",
|
||||||
side_effect=Exception("Unexpected"),
|
side_effect=Exception("Unexpected"),
|
||||||
):
|
):
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ async def login_user(client, email: str, password: str = "SecurePassword123!"):
|
|||||||
|
|
||||||
async def create_superuser(e2e_db_session, email: str, password: str):
|
async def create_superuser(e2e_db_session, email: str, password: str):
|
||||||
"""Create a superuser directly in the database."""
|
"""Create a superuser directly in the database."""
|
||||||
from app.crud.user import user as user_crud
|
from app.repositories.user import user_repo as user_crud
|
||||||
from app.schemas.users import UserCreate
|
from app.schemas.users import UserCreate
|
||||||
|
|
||||||
user_in = UserCreate(
|
user_in = UserCreate(
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ async def register_and_login(client, email: str, password: str = "SecurePassword
|
|||||||
|
|
||||||
async def create_superuser_and_login(client, db_session):
|
async def create_superuser_and_login(client, db_session):
|
||||||
"""Helper to create a superuser directly in DB and login."""
|
"""Helper to create a superuser directly in DB and login."""
|
||||||
from app.crud.user import user as user_crud
|
from app.repositories.user import user_repo as user_crud
|
||||||
from app.schemas.users import UserCreate
|
from app.schemas.users import UserCreate
|
||||||
|
|
||||||
email = f"admin-{uuid4().hex[:8]}@example.com"
|
email = f"admin-{uuid4().hex[:8]}@example.com"
|
||||||
|
|||||||
@@ -11,7 +11,12 @@ import pytest
|
|||||||
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from app.crud.user import user as user_crud
|
from app.core.repository_exceptions import (
|
||||||
|
DuplicateEntryError,
|
||||||
|
IntegrityConstraintError,
|
||||||
|
InvalidInputError,
|
||||||
|
)
|
||||||
|
from app.repositories.user import user_repo as user_crud
|
||||||
from app.schemas.users import UserCreate, UserUpdate
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
|
||||||
|
|
||||||
@@ -81,7 +86,7 @@ class TestCRUDBaseGetMulti:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
|
||||||
await user_crud.get_multi(session, skip=-1)
|
await user_crud.get_multi(session, skip=-1)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -90,7 +95,7 @@ class TestCRUDBaseGetMulti:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
|
||||||
await user_crud.get_multi(session, limit=-1)
|
await user_crud.get_multi(session, limit=-1)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -99,7 +104,7 @@ class TestCRUDBaseGetMulti:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
|
||||||
await user_crud.get_multi(session, limit=1001)
|
await user_crud.get_multi(session, limit=1001)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -140,7 +145,7 @@ class TestCRUDBaseCreate:
|
|||||||
last_name="Duplicate",
|
last_name="Duplicate",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="already exists"):
|
with pytest.raises(DuplicateEntryError, match="already exists"):
|
||||||
await user_crud.create(session, obj_in=user_data)
|
await user_crud.create(session, obj_in=user_data)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -165,7 +170,7 @@ class TestCRUDBaseCreate:
|
|||||||
last_name="User",
|
last_name="User",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Database integrity error"):
|
with pytest.raises(DuplicateEntryError, match="Database integrity error"):
|
||||||
await user_crud.create(session, obj_in=user_data)
|
await user_crud.create(session, obj_in=user_data)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -244,7 +249,7 @@ class TestCRUDBaseUpdate:
|
|||||||
|
|
||||||
# Create another user
|
# Create another user
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
from app.crud.user import user as user_crud
|
from app.repositories.user import user_repo as user_crud
|
||||||
|
|
||||||
user2_data = UserCreate(
|
user2_data = UserCreate(
|
||||||
email="user2@example.com",
|
email="user2@example.com",
|
||||||
@@ -268,7 +273,7 @@ class TestCRUDBaseUpdate:
|
|||||||
):
|
):
|
||||||
update_data = UserUpdate(email=async_test_user.email)
|
update_data = UserUpdate(email=async_test_user.email)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="already exists"):
|
with pytest.raises(DuplicateEntryError, match="already exists"):
|
||||||
await user_crud.update(
|
await user_crud.update(
|
||||||
session, db_obj=user2_obj, obj_in=update_data
|
session, db_obj=user2_obj, obj_in=update_data
|
||||||
)
|
)
|
||||||
@@ -302,7 +307,7 @@ class TestCRUDBaseUpdate:
|
|||||||
"statement", {}, Exception("constraint failed")
|
"statement", {}, Exception("constraint failed")
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
with pytest.raises(ValueError, match="Database integrity error"):
|
with pytest.raises(IntegrityConstraintError, match="Database integrity error"):
|
||||||
await user_crud.update(
|
await user_crud.update(
|
||||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||||
)
|
)
|
||||||
@@ -322,7 +327,7 @@ class TestCRUDBaseUpdate:
|
|||||||
"statement", {}, Exception("connection error")
|
"statement", {}, Exception("connection error")
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
with pytest.raises(ValueError, match="Database operation failed"):
|
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
|
||||||
await user_crud.update(
|
await user_crud.update(
|
||||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||||
)
|
)
|
||||||
@@ -403,7 +408,7 @@ class TestCRUDBaseRemove:
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Cannot delete.*referenced by other records"
|
IntegrityConstraintError, match="Cannot delete.*referenced by other records"
|
||||||
):
|
):
|
||||||
await user_crud.remove(session, id=str(async_test_user.id))
|
await user_crud.remove(session, id=str(async_test_user.id))
|
||||||
|
|
||||||
@@ -442,7 +447,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
|
||||||
await user_crud.get_multi_with_total(session, skip=-1)
|
await user_crud.get_multi_with_total(session, skip=-1)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -451,7 +456,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
|
||||||
await user_crud.get_multi_with_total(session, limit=-1)
|
await user_crud.get_multi_with_total(session, limit=-1)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -460,7 +465,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
|
||||||
await user_crud.get_multi_with_total(session, limit=1001)
|
await user_crud.get_multi_with_total(session, limit=1001)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -827,7 +832,7 @@ class TestCRUDBasePaginationValidation:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
|
||||||
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -836,7 +841,7 @@ class TestCRUDBasePaginationValidation:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
|
||||||
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -845,7 +850,7 @@ class TestCRUDBasePaginationValidation:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
|
||||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -899,7 +904,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
# Create an organization (which doesn't have deleted_at)
|
# Create an organization (which doesn't have deleted_at)
|
||||||
from app.crud.organization import organization as org_crud
|
from app.repositories.organization import organization_repo as org_crud
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
@@ -910,7 +915,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
|||||||
|
|
||||||
# Try to soft delete organization (should fail)
|
# Try to soft delete organization (should fail)
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="does not have a deleted_at column"):
|
with pytest.raises(InvalidInputError, match="does not have a deleted_at column"):
|
||||||
await org_crud.soft_delete(session, id=str(org_id))
|
await org_crud.soft_delete(session, id=str(org_id))
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -919,7 +924,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
# Create an organization (which doesn't have deleted_at)
|
# Create an organization (which doesn't have deleted_at)
|
||||||
from app.crud.organization import organization as org_crud
|
from app.repositories.organization import organization_repo as org_crud
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
@@ -930,7 +935,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
|||||||
|
|
||||||
# Try to restore organization (should fail)
|
# Try to restore organization (should fail)
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="does not have a deleted_at column"):
|
with pytest.raises(InvalidInputError, match="does not have a deleted_at column"):
|
||||||
await org_crud.restore(session, id=str(org_id))
|
await org_crud.restore(session, id=str(org_id))
|
||||||
|
|
||||||
|
|
||||||
@@ -950,7 +955,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
# Create a session for the user
|
# Create a session for the user
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
@@ -989,7 +994,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
|||||||
_test_engine, SessionLocal = async_test_db
|
_test_engine, SessionLocal = async_test_db
|
||||||
|
|
||||||
# Create multiple sessions for the user
|
# Create multiple sessions for the user
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
@@ -10,7 +10,8 @@ from uuid import uuid4
|
|||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.exc import DataError, OperationalError
|
from sqlalchemy.exc import DataError, OperationalError
|
||||||
|
|
||||||
from app.crud.user import user as user_crud
|
from app.core.repository_exceptions import IntegrityConstraintError
|
||||||
|
from app.repositories.user import user_repo as user_crud
|
||||||
from app.schemas.users import UserCreate
|
from app.schemas.users import UserCreate
|
||||||
|
|
||||||
|
|
||||||
@@ -119,7 +120,7 @@ class TestBaseCRUDUpdateFailures:
|
|||||||
with patch.object(
|
with patch.object(
|
||||||
session, "rollback", new_callable=AsyncMock
|
session, "rollback", new_callable=AsyncMock
|
||||||
) as mock_rollback:
|
) as mock_rollback:
|
||||||
with pytest.raises(ValueError, match="Database operation failed"):
|
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
|
||||||
await user_crud.update(
|
await user_crud.update(
|
||||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||||
)
|
)
|
||||||
@@ -141,7 +142,7 @@ class TestBaseCRUDUpdateFailures:
|
|||||||
with patch.object(
|
with patch.object(
|
||||||
session, "rollback", new_callable=AsyncMock
|
session, "rollback", new_callable=AsyncMock
|
||||||
) as mock_rollback:
|
) as mock_rollback:
|
||||||
with pytest.raises(ValueError, match="Database operation failed"):
|
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
|
||||||
await user_crud.update(
|
await user_crud.update(
|
||||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||||
)
|
)
|
||||||
@@ -7,7 +7,10 @@ from datetime import UTC, datetime, timedelta
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.crud.oauth import oauth_account, oauth_client, oauth_state
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
|
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||||
|
from app.repositories.oauth_client import oauth_client_repo as oauth_client
|
||||||
|
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||||
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
|
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
|
||||||
|
|
||||||
|
|
||||||
@@ -60,7 +63,7 @@ class TestOAuthAccountCRUD:
|
|||||||
|
|
||||||
# SQLite returns different error message than PostgreSQL
|
# SQLite returns different error message than PostgreSQL
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="(already linked|UNIQUE constraint failed)"
|
DuplicateEntryError, match="(already linked|UNIQUE constraint failed|Failed to create)"
|
||||||
):
|
):
|
||||||
await oauth_account.create_account(session, obj_in=account_data2)
|
await oauth_account.create_account(session, obj_in=account_data2)
|
||||||
|
|
||||||
@@ -256,13 +259,13 @@ class TestOAuthAccountCRUD:
|
|||||||
updated = await oauth_account.update_tokens(
|
updated = await oauth_account.update_tokens(
|
||||||
session,
|
session,
|
||||||
account=account,
|
account=account,
|
||||||
access_token_encrypted="new_access_token",
|
access_token="new_access_token",
|
||||||
refresh_token_encrypted="new_refresh_token",
|
refresh_token="new_refresh_token",
|
||||||
token_expires_at=new_expires,
|
token_expires_at=new_expires,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert updated.access_token_encrypted == "new_access_token"
|
assert updated.access_token == "new_access_token"
|
||||||
assert updated.refresh_token_encrypted == "new_refresh_token"
|
assert updated.refresh_token == "new_refresh_token"
|
||||||
|
|
||||||
|
|
||||||
class TestOAuthStateCRUD:
|
class TestOAuthStateCRUD:
|
||||||
@@ -9,7 +9,8 @@ from uuid import uuid4
|
|||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.crud.organization import organization as organization_crud
|
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
|
||||||
|
from app.repositories.organization import organization_repo as organization_crud
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||||
from app.schemas.organizations import OrganizationCreate
|
from app.schemas.organizations import OrganizationCreate
|
||||||
@@ -87,7 +88,7 @@ class TestCreate:
|
|||||||
# Try to create second with same slug
|
# Try to create second with same slug
|
||||||
async with AsyncTestingSessionLocal() as session:
|
async with AsyncTestingSessionLocal() as session:
|
||||||
org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug")
|
org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug")
|
||||||
with pytest.raises(ValueError, match="already exists"):
|
with pytest.raises(DuplicateEntryError, match="already exists"):
|
||||||
await organization_crud.create(session, obj_in=org_in)
|
await organization_crud.create(session, obj_in=org_in)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -295,7 +296,7 @@ class TestAddUser:
|
|||||||
org_id = org.id
|
org_id = org.id
|
||||||
|
|
||||||
async with AsyncTestingSessionLocal() as session:
|
async with AsyncTestingSessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="already a member"):
|
with pytest.raises(DuplicateEntryError, match="already a member"):
|
||||||
await organization_crud.add_user(
|
await organization_crud.add_user(
|
||||||
session, organization_id=org_id, user_id=async_test_user.id
|
session, organization_id=org_id, user_id=async_test_user.id
|
||||||
)
|
)
|
||||||
@@ -972,7 +973,7 @@ class TestOrganizationExceptionHandlers:
|
|||||||
with patch.object(session, "commit", side_effect=mock_commit):
|
with patch.object(session, "commit", side_effect=mock_commit):
|
||||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||||
org_in = OrganizationCreate(name="Test", slug="test")
|
org_in = OrganizationCreate(name="Test", slug="test")
|
||||||
with pytest.raises(ValueError, match="Database integrity error"):
|
with pytest.raises(IntegrityConstraintError, match="Database integrity error"):
|
||||||
await organization_crud.create(session, obj_in=org_in)
|
await organization_crud.create(session, obj_in=org_in)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -1058,7 +1059,7 @@ class TestOrganizationExceptionHandlers:
|
|||||||
with patch.object(session, "commit", side_effect=mock_commit):
|
with patch.object(session, "commit", side_effect=mock_commit):
|
||||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Failed to add user to organization"
|
IntegrityConstraintError, match="Failed to add user to organization"
|
||||||
):
|
):
|
||||||
await organization_crud.add_user(
|
await organization_crud.add_user(
|
||||||
session,
|
session,
|
||||||
@@ -8,7 +8,8 @@ from uuid import uuid4
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.crud.session import session as session_crud
|
from app.core.repository_exceptions import InvalidInputError
|
||||||
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
from app.schemas.sessions import SessionCreate
|
from app.schemas.sessions import SessionCreate
|
||||||
|
|
||||||
@@ -503,7 +504,7 @@ class TestCleanupExpiredForUser:
|
|||||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
async with AsyncTestingSessionLocal() as session:
|
async with AsyncTestingSessionLocal() as session:
|
||||||
with pytest.raises(ValueError, match="Invalid user ID format"):
|
with pytest.raises(InvalidInputError, match="Invalid user ID format"):
|
||||||
await session_crud.cleanup_expired_for_user(
|
await session_crud.cleanup_expired_for_user(
|
||||||
session, user_id="not-a-valid-uuid"
|
session, user_id="not-a-valid-uuid"
|
||||||
)
|
)
|
||||||
@@ -10,7 +10,8 @@ from uuid import uuid4
|
|||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.exc import OperationalError
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
from app.crud.session import session as session_crud
|
from app.core.repository_exceptions import IntegrityConstraintError
|
||||||
|
from app.repositories.session import session_repo as session_crud
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
from app.schemas.sessions import SessionCreate
|
from app.schemas.sessions import SessionCreate
|
||||||
|
|
||||||
@@ -102,7 +103,7 @@ class TestSessionCRUDCreateSessionFailures:
|
|||||||
last_used_at=datetime.now(UTC),
|
last_used_at=datetime.now(UTC),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Failed to create session"):
|
with pytest.raises(IntegrityConstraintError, match="Failed to create session"):
|
||||||
await session_crud.create_session(session, obj_in=session_data)
|
await session_crud.create_session(session, obj_in=session_data)
|
||||||
|
|
||||||
mock_rollback.assert_called_once()
|
mock_rollback.assert_called_once()
|
||||||
@@ -133,7 +134,7 @@ class TestSessionCRUDCreateSessionFailures:
|
|||||||
last_used_at=datetime.now(UTC),
|
last_used_at=datetime.now(UTC),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Failed to create session"):
|
with pytest.raises(IntegrityConstraintError, match="Failed to create session"):
|
||||||
await session_crud.create_session(session, obj_in=session_data)
|
await session_crud.create_session(session, obj_in=session_data)
|
||||||
|
|
||||||
mock_rollback.assert_called_once()
|
mock_rollback.assert_called_once()
|
||||||
@@ -5,7 +5,8 @@ Comprehensive tests for async user CRUD operations.
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.crud.user import user as user_crud
|
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
|
||||||
|
from app.repositories.user import user_repo as user_crud
|
||||||
from app.schemas.users import UserCreate, UserUpdate
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
|
||||||
|
|
||||||
@@ -93,7 +94,7 @@ class TestCreate:
|
|||||||
last_name="User",
|
last_name="User",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(DuplicateEntryError) as exc_info:
|
||||||
await user_crud.create(session, obj_in=user_data)
|
await user_crud.create(session, obj_in=user_data)
|
||||||
|
|
||||||
assert "already exists" in str(exc_info.value).lower()
|
assert "already exists" in str(exc_info.value).lower()
|
||||||
@@ -330,7 +331,7 @@ class TestGetMultiWithTotal:
|
|||||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
async with AsyncTestingSessionLocal() as session:
|
async with AsyncTestingSessionLocal() as session:
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(InvalidInputError) as exc_info:
|
||||||
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
||||||
|
|
||||||
assert "skip must be non-negative" in str(exc_info.value)
|
assert "skip must be non-negative" in str(exc_info.value)
|
||||||
@@ -341,7 +342,7 @@ class TestGetMultiWithTotal:
|
|||||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
async with AsyncTestingSessionLocal() as session:
|
async with AsyncTestingSessionLocal() as session:
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(InvalidInputError) as exc_info:
|
||||||
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
||||||
|
|
||||||
assert "limit must be non-negative" in str(exc_info.value)
|
assert "limit must be non-negative" in str(exc_info.value)
|
||||||
@@ -352,7 +353,7 @@ class TestGetMultiWithTotal:
|
|||||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
async with AsyncTestingSessionLocal() as session:
|
async with AsyncTestingSessionLocal() as session:
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(InvalidInputError) as exc_info:
|
||||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||||
|
|
||||||
assert "Maximum limit is 1000" in str(exc_info.value)
|
assert "Maximum limit is 1000" in str(exc_info.value)
|
||||||
@@ -10,6 +10,7 @@ from app.core.auth import (
|
|||||||
get_password_hash,
|
get_password_hash,
|
||||||
verify_password,
|
verify_password,
|
||||||
)
|
)
|
||||||
|
from app.core.exceptions import DuplicateError
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import Token, UserCreate
|
from app.schemas.users import Token, UserCreate
|
||||||
from app.services.auth_service import AuthenticationError, AuthService
|
from app.services.auth_service import AuthenticationError, AuthService
|
||||||
@@ -152,9 +153,9 @@ class TestAuthServiceUserCreation:
|
|||||||
last_name="User",
|
last_name="User",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should raise AuthenticationError
|
# Should raise DuplicateError for duplicate email
|
||||||
async with AsyncTestingSessionLocal() as session:
|
async with AsyncTestingSessionLocal() as session:
|
||||||
with pytest.raises(AuthenticationError):
|
with pytest.raises(DuplicateError):
|
||||||
await AuthService.create_user(db=session, user_data=user_data)
|
await AuthService.create_user(db=session, user_data=user_data)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -269,18 +269,18 @@ class TestClientValidation:
|
|||||||
async def test_validate_client_legacy_sha256_hash(
|
async def test_validate_client_legacy_sha256_hash(
|
||||||
self, db, confidential_client_legacy_hash
|
self, db, confidential_client_legacy_hash
|
||||||
):
|
):
|
||||||
"""Test validating a client with legacy SHA-256 hash (backward compatibility)."""
|
"""Test that legacy SHA-256 hash is rejected with clear error message."""
|
||||||
client, secret = confidential_client_legacy_hash
|
client, secret = confidential_client_legacy_hash
|
||||||
validated = await service.validate_client(db, client.client_id, secret)
|
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
|
||||||
assert validated.client_id == client.client_id
|
await service.validate_client(db, client.client_id, secret)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_validate_client_legacy_sha256_wrong_secret(
|
async def test_validate_client_legacy_sha256_wrong_secret(
|
||||||
self, db, confidential_client_legacy_hash
|
self, db, confidential_client_legacy_hash
|
||||||
):
|
):
|
||||||
"""Test legacy SHA-256 client rejects wrong secret."""
|
"""Test that legacy SHA-256 client with wrong secret is rejected."""
|
||||||
client, _ = confidential_client_legacy_hash
|
client, _ = confidential_client_legacy_hash
|
||||||
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
|
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
|
||||||
await service.validate_client(db, client.client_id, "wrong_secret")
|
await service.validate_client(db, client.client_id, "wrong_secret")
|
||||||
|
|
||||||
def test_validate_redirect_uri_success(self, public_client):
|
def test_validate_redirect_uri_success(self, public_client):
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ from uuid import uuid4
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.core.exceptions import AuthenticationError
|
from app.core.exceptions import AuthenticationError
|
||||||
from app.crud.oauth import oauth_account, oauth_state
|
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||||
|
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||||
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
|
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
|
||||||
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService
|
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService
|
||||||
|
|
||||||
|
|||||||
447
backend/tests/services/test_organization_service.py
Normal file
447
backend/tests/services/test_organization_service.py
Normal file
@@ -0,0 +1,447 @@
|
|||||||
|
# tests/services/test_organization_service.py
|
||||||
|
"""Tests for the OrganizationService class."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.core.exceptions import NotFoundError
|
||||||
|
from app.models.user_organization import OrganizationRole
|
||||||
|
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
|
||||||
|
from app.services.organization_service import OrganizationService, organization_service
|
||||||
|
|
||||||
|
|
||||||
|
def _make_org_create(name=None, slug=None) -> OrganizationCreate:
|
||||||
|
"""Helper to create an OrganizationCreate schema with unique defaults."""
|
||||||
|
unique = uuid.uuid4().hex[:8]
|
||||||
|
return OrganizationCreate(
|
||||||
|
name=name or f"Test Org {unique}",
|
||||||
|
slug=slug or f"test-org-{unique}",
|
||||||
|
description="A test organization",
|
||||||
|
is_active=True,
|
||||||
|
settings={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetOrganization:
|
||||||
|
"""Tests for OrganizationService.get_organization method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_organization_found(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting an existing organization by ID returns the org."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await organization_service.get_organization(
|
||||||
|
session, str(created.id)
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == created.id
|
||||||
|
assert result.slug == created.slug
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_organization_not_found(self, async_test_db):
|
||||||
|
"""Test getting a non-existent organization raises NotFoundError."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await organization_service.get_organization(
|
||||||
|
session, str(uuid.uuid4())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateOrganization:
|
||||||
|
"""Tests for OrganizationService.create_organization method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_organization(self, async_test_db, async_test_user):
|
||||||
|
"""Test creating a new organization returns the created org with correct fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
obj_in = _make_org_create()
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await organization_service.create_organization(
|
||||||
|
session, obj_in=obj_in
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.name == obj_in.name
|
||||||
|
assert result.slug == obj_in.slug
|
||||||
|
assert result.description == obj_in.description
|
||||||
|
assert result.is_active is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateOrganization:
|
||||||
|
"""Tests for OrganizationService.update_organization method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_organization(self, async_test_db, async_test_user):
|
||||||
|
"""Test updating an organization name."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
org = await organization_service.get_organization(session, str(created.id))
|
||||||
|
updated = await organization_service.update_organization(
|
||||||
|
session,
|
||||||
|
org=org,
|
||||||
|
obj_in=OrganizationUpdate(name="Updated Org Name"),
|
||||||
|
)
|
||||||
|
assert updated.name == "Updated Org Name"
|
||||||
|
assert updated.id == created.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_organization_with_dict(self, async_test_db, async_test_user):
|
||||||
|
"""Test updating an organization using a dict."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
org = await organization_service.get_organization(session, str(created.id))
|
||||||
|
updated = await organization_service.update_organization(
|
||||||
|
session,
|
||||||
|
org=org,
|
||||||
|
obj_in={"description": "Updated description"},
|
||||||
|
)
|
||||||
|
assert updated.description == "Updated description"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRemoveOrganization:
|
||||||
|
"""Tests for OrganizationService.remove_organization method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_organization(self, async_test_db, async_test_user):
|
||||||
|
"""Test permanently deleting an organization."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
org_id = str(created.id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await organization_service.remove_organization(session, org_id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await organization_service.get_organization(session, org_id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetMemberCount:
|
||||||
|
"""Tests for OrganizationService.get_member_count method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_member_count_empty(self, async_test_db, async_test_user):
|
||||||
|
"""Test member count for org with no members is zero."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
count = await organization_service.get_member_count(
|
||||||
|
session, organization_id=created.id
|
||||||
|
)
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_member_count_with_member(self, async_test_db, async_test_user):
|
||||||
|
"""Test member count increases after adding a member."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await organization_service.add_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
count = await organization_service.get_member_count(
|
||||||
|
session, organization_id=created.id
|
||||||
|
)
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetMultiWithMemberCounts:
|
||||||
|
"""Tests for OrganizationService.get_multi_with_member_counts method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_member_counts(self, async_test_db, async_test_user):
|
||||||
|
"""Test listing organizations with member counts returns tuple."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
orgs, count = await organization_service.get_multi_with_member_counts(
|
||||||
|
session, skip=0, limit=10
|
||||||
|
)
|
||||||
|
assert isinstance(orgs, list)
|
||||||
|
assert isinstance(count, int)
|
||||||
|
assert count >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_member_counts_search(
|
||||||
|
self, async_test_db, async_test_user
|
||||||
|
):
|
||||||
|
"""Test listing organizations with a search filter."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
unique = uuid.uuid4().hex[:8]
|
||||||
|
org_name = f"Searchable Org {unique}"
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await organization_service.create_organization(
|
||||||
|
session,
|
||||||
|
obj_in=OrganizationCreate(
|
||||||
|
name=org_name,
|
||||||
|
slug=f"searchable-org-{unique}",
|
||||||
|
is_active=True,
|
||||||
|
settings={},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
orgs, count = await organization_service.get_multi_with_member_counts(
|
||||||
|
session, skip=0, limit=10, search=f"Searchable Org {unique}"
|
||||||
|
)
|
||||||
|
assert count >= 1
|
||||||
|
# Each element is a dict with key "organization" (an Organization obj) and "member_count"
|
||||||
|
names = [o["organization"].name for o in orgs]
|
||||||
|
assert org_name in names
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUserOrganizationsWithDetails:
|
||||||
|
"""Tests for OrganizationService.get_user_organizations_with_details method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_organizations_with_details(
|
||||||
|
self, async_test_db, async_test_user
|
||||||
|
):
|
||||||
|
"""Test getting organizations for a user returns list of dicts."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
await organization_service.add_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
orgs = await organization_service.get_user_organizations_with_details(
|
||||||
|
session, user_id=async_test_user.id
|
||||||
|
)
|
||||||
|
assert isinstance(orgs, list)
|
||||||
|
assert len(orgs) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetOrganizationMembers:
|
||||||
|
"""Tests for OrganizationService.get_organization_members method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_organization_members(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting organization members returns paginated results."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
await organization_service.add_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
members, count = await organization_service.get_organization_members(
|
||||||
|
session, organization_id=created.id, skip=0, limit=10
|
||||||
|
)
|
||||||
|
assert isinstance(members, list)
|
||||||
|
assert isinstance(count, int)
|
||||||
|
assert count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddMember:
|
||||||
|
"""Tests for OrganizationService.add_member method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_member_default_role(self, async_test_db, async_test_user):
|
||||||
|
"""Test adding a user to an org with default MEMBER role."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
membership = await organization_service.add_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
)
|
||||||
|
assert membership is not None
|
||||||
|
assert membership.user_id == async_test_user.id
|
||||||
|
assert membership.organization_id == created.id
|
||||||
|
assert membership.role == OrganizationRole.MEMBER
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_member_admin_role(self, async_test_db, async_test_user):
|
||||||
|
"""Test adding a user to an org with ADMIN role."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
membership = await organization_service.add_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
role=OrganizationRole.ADMIN,
|
||||||
|
)
|
||||||
|
assert membership.role == OrganizationRole.ADMIN
|
||||||
|
|
||||||
|
|
||||||
|
class TestRemoveMember:
|
||||||
|
"""Tests for OrganizationService.remove_member method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_member(self, async_test_db, async_test_user):
|
||||||
|
"""Test removing a member from an org returns True."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
await organization_service.add_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
removed = await organization_service.remove_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
)
|
||||||
|
assert removed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_member_not_found(self, async_test_db, async_test_user):
|
||||||
|
"""Test removing a non-member returns False."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
removed = await organization_service.remove_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
)
|
||||||
|
assert removed is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUserRoleInOrg:
|
||||||
|
"""Tests for OrganizationService.get_user_role_in_org method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_role_in_org(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting a user's role in an org they belong to."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
await organization_service.add_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
role=OrganizationRole.MEMBER,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
role = await organization_service.get_user_role_in_org(
|
||||||
|
session,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
organization_id=created.id,
|
||||||
|
)
|
||||||
|
assert role == OrganizationRole.MEMBER
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_role_in_org_not_member(
|
||||||
|
self, async_test_db, async_test_user
|
||||||
|
):
|
||||||
|
"""Test getting role for a user not in the org returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
role = await organization_service.get_user_role_in_org(
|
||||||
|
session,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
organization_id=created.id,
|
||||||
|
)
|
||||||
|
assert role is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetOrgDistribution:
|
||||||
|
"""Tests for OrganizationService.get_org_distribution method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_org_distribution_empty(self, async_test_db):
|
||||||
|
"""Test org distribution with no memberships returns empty list."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await organization_service.get_org_distribution(session, limit=6)
|
||||||
|
assert isinstance(result, list)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_org_distribution_with_members(
|
||||||
|
self, async_test_db, async_test_user
|
||||||
|
):
|
||||||
|
"""Test org distribution returns org name and member count."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await organization_service.create_organization(
|
||||||
|
session, obj_in=_make_org_create()
|
||||||
|
)
|
||||||
|
await organization_service.add_member(
|
||||||
|
session,
|
||||||
|
organization_id=created.id,
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await organization_service.get_org_distribution(session, limit=6)
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) >= 1
|
||||||
|
entry = result[0]
|
||||||
|
assert "name" in entry
|
||||||
|
assert "value" in entry
|
||||||
|
assert entry["value"] >= 1
|
||||||
292
backend/tests/services/test_session_service.py
Normal file
292
backend/tests/services/test_session_service.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
# tests/services/test_session_service.py
|
||||||
|
"""Tests for the SessionService class."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.schemas.sessions import SessionCreate
|
||||||
|
from app.services.session_service import SessionService, session_service
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session_create(user_id, jti=None) -> SessionCreate:
|
||||||
|
"""Helper to build a SessionCreate with sensible defaults."""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
return SessionCreate(
|
||||||
|
user_id=user_id,
|
||||||
|
refresh_token_jti=jti or str(uuid.uuid4()),
|
||||||
|
ip_address="127.0.0.1",
|
||||||
|
user_agent="pytest/test",
|
||||||
|
device_name="Test Device",
|
||||||
|
device_id="test-device-id",
|
||||||
|
last_used_at=now,
|
||||||
|
expires_at=now + timedelta(days=7),
|
||||||
|
location_city="TestCity",
|
||||||
|
location_country="TestCountry",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSession:
|
||||||
|
"""Tests for SessionService.create_session method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_session(self, async_test_db, async_test_user):
|
||||||
|
"""Test creating a session returns a UserSession with correct fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
obj_in = _make_session_create(async_test_user.id)
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
assert result is not None
|
||||||
|
assert result.user_id == async_test_user.id
|
||||||
|
assert result.refresh_token_jti == obj_in.refresh_token_jti
|
||||||
|
assert result.is_active is True
|
||||||
|
assert result.ip_address == "127.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSession:
|
||||||
|
"""Tests for SessionService.get_session method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_session_found(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting a session by ID returns the session."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
obj_in = _make_session_create(async_test_user.id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session_service.get_session(session, str(created.id))
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == created.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_session_not_found(self, async_test_db):
|
||||||
|
"""Test getting a non-existent session returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session_service.get_session(session, str(uuid.uuid4()))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUserSessions:
|
||||||
|
"""Tests for SessionService.get_user_sessions method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting active sessions for a user returns only active sessions."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
obj_in = _make_session_create(async_test_user.id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sessions = await session_service.get_user_sessions(
|
||||||
|
session, user_id=str(async_test_user.id), active_only=True
|
||||||
|
)
|
||||||
|
assert isinstance(sessions, list)
|
||||||
|
assert len(sessions) >= 1
|
||||||
|
for s in sessions:
|
||||||
|
assert s.is_active is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting all sessions (active and inactive) for a user."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
obj_in = _make_session_create(async_test_user.id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
await session_service.deactivate(session, session_id=str(created.id))
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sessions = await session_service.get_user_sessions(
|
||||||
|
session, user_id=str(async_test_user.id), active_only=False
|
||||||
|
)
|
||||||
|
assert isinstance(sessions, list)
|
||||||
|
assert len(sessions) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetActiveByJti:
|
||||||
|
"""Tests for SessionService.get_active_by_jti method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_by_jti_found(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting an active session by JTI returns the session."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
jti = str(uuid.uuid4())
|
||||||
|
obj_in = _make_session_create(async_test_user.id, jti=jti)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session_service.get_active_by_jti(session, jti=jti)
|
||||||
|
assert result is not None
|
||||||
|
assert result.refresh_token_jti == jti
|
||||||
|
assert result.is_active is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_by_jti_not_found(self, async_test_db):
|
||||||
|
"""Test getting an active session by non-existent JTI returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session_service.get_active_by_jti(
|
||||||
|
session, jti=str(uuid.uuid4())
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetByJti:
|
||||||
|
"""Tests for SessionService.get_by_jti method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_jti_active(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting a session (active or inactive) by JTI."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
jti = str(uuid.uuid4())
|
||||||
|
obj_in = _make_session_create(async_test_user.id, jti=jti)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session_service.get_by_jti(session, jti=jti)
|
||||||
|
assert result is not None
|
||||||
|
assert result.refresh_token_jti == jti
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeactivate:
|
||||||
|
"""Tests for SessionService.deactivate method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deactivate_session(self, async_test_db, async_test_user):
|
||||||
|
"""Test deactivating a session sets is_active to False."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
obj_in = _make_session_create(async_test_user.id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
session_id = str(created.id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
deactivated = await session_service.deactivate(
|
||||||
|
session, session_id=session_id
|
||||||
|
)
|
||||||
|
assert deactivated is not None
|
||||||
|
assert deactivated.is_active is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeactivateAllUserSessions:
|
||||||
|
"""Tests for SessionService.deactivate_all_user_sessions method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deactivate_all_user_sessions(self, async_test_db, async_test_user):
|
||||||
|
"""Test deactivating all sessions for a user returns count deactivated."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await session_service.create_session(
|
||||||
|
session, obj_in=_make_session_create(async_test_user.id)
|
||||||
|
)
|
||||||
|
await session_service.create_session(
|
||||||
|
session, obj_in=_make_session_create(async_test_user.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
count = await session_service.deactivate_all_user_sessions(
|
||||||
|
session, user_id=str(async_test_user.id)
|
||||||
|
)
|
||||||
|
assert count >= 2
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
active_sessions = await session_service.get_user_sessions(
|
||||||
|
session, user_id=str(async_test_user.id), active_only=True
|
||||||
|
)
|
||||||
|
assert len(active_sessions) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateRefreshToken:
|
||||||
|
"""Tests for SessionService.update_refresh_token method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_refresh_token(self, async_test_db, async_test_user):
|
||||||
|
"""Test rotating a session's refresh token updates JTI and expiry."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
obj_in = _make_session_create(async_test_user.id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
session_id = str(created.id)
|
||||||
|
|
||||||
|
new_jti = str(uuid.uuid4())
|
||||||
|
new_expires_at = datetime.now(UTC) + timedelta(days=14)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session_service.get_session(session, session_id)
|
||||||
|
updated = await session_service.update_refresh_token(
|
||||||
|
session,
|
||||||
|
session=result,
|
||||||
|
new_jti=new_jti,
|
||||||
|
new_expires_at=new_expires_at,
|
||||||
|
)
|
||||||
|
assert updated.refresh_token_jti == new_jti
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupExpiredForUser:
|
||||||
|
"""Tests for SessionService.cleanup_expired_for_user method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_expired_for_user(self, async_test_db, async_test_user):
|
||||||
|
"""Test cleaning up expired inactive sessions returns count removed."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
# Create a session that is already expired
|
||||||
|
obj_in = SessionCreate(
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
refresh_token_jti=str(uuid.uuid4()),
|
||||||
|
ip_address="127.0.0.1",
|
||||||
|
user_agent="pytest/test",
|
||||||
|
last_used_at=now - timedelta(days=8),
|
||||||
|
expires_at=now - timedelta(days=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
created = await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
session_id = str(created.id)
|
||||||
|
|
||||||
|
# Deactivate it so it qualifies for cleanup (requires is_active=False AND expired)
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await session_service.deactivate(session, session_id=session_id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
count = await session_service.cleanup_expired_for_user(
|
||||||
|
session, user_id=str(async_test_user.id)
|
||||||
|
)
|
||||||
|
assert isinstance(count, int)
|
||||||
|
assert count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAllSessions:
|
||||||
|
"""Tests for SessionService.get_all_sessions method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_all_sessions(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting all sessions with pagination returns tuple of list and count."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
obj_in = _make_session_create(async_test_user.id)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await session_service.create_session(session, obj_in=obj_in)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sessions, count = await session_service.get_all_sessions(
|
||||||
|
session, skip=0, limit=10, active_only=True, with_user=False
|
||||||
|
)
|
||||||
|
assert isinstance(sessions, list)
|
||||||
|
assert isinstance(count, int)
|
||||||
|
assert count >= 1
|
||||||
|
assert len(sessions) >= 1
|
||||||
214
backend/tests/services/test_user_service.py
Normal file
214
backend/tests/services/test_user_service.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
# tests/services/test_user_service.py
|
||||||
|
"""Tests for the UserService class."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.exceptions import NotFoundError
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
from app.services.user_service import UserService, user_service
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUser:
|
||||||
|
"""Tests for UserService.get_user method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_found(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting an existing user by ID returns the user."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await user_service.get_user(session, str(async_test_user.id))
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == async_test_user.id
|
||||||
|
assert result.email == async_test_user.email
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_not_found(self, async_test_db):
|
||||||
|
"""Test getting a non-existent user raises NotFoundError."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
non_existent_id = str(uuid.uuid4())
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await user_service.get_user(session, non_existent_id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetByEmail:
|
||||||
|
"""Tests for UserService.get_by_email method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_email_found(self, async_test_db, async_test_user):
|
||||||
|
"""Test getting an existing user by email returns the user."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await user_service.get_by_email(session, async_test_user.email)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == async_test_user.id
|
||||||
|
assert result.email == async_test_user.email
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_email_not_found(self, async_test_db):
|
||||||
|
"""Test getting a user by non-existent email returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await user_service.get_by_email(session, "nonexistent@example.com")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateUser:
|
||||||
|
"""Tests for UserService.create_user method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_user(self, async_test_db):
|
||||||
|
"""Test creating a new user with valid data."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
unique_email = f"test_{uuid.uuid4()}@example.com"
|
||||||
|
user_data = UserCreate(
|
||||||
|
email=unique_email,
|
||||||
|
password="TestPassword123!",
|
||||||
|
first_name="New",
|
||||||
|
last_name="User",
|
||||||
|
)
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await user_service.create_user(session, user_data)
|
||||||
|
assert result is not None
|
||||||
|
assert result.email == unique_email
|
||||||
|
assert result.first_name == "New"
|
||||||
|
assert result.last_name == "User"
|
||||||
|
assert result.is_active is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateUser:
|
||||||
|
"""Tests for UserService.update_user method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_user(self, async_test_db, async_test_user):
|
||||||
|
"""Test updating a user's first_name."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
user = await user_service.get_user(session, str(async_test_user.id))
|
||||||
|
updated = await user_service.update_user(
|
||||||
|
session,
|
||||||
|
user=user,
|
||||||
|
obj_in=UserUpdate(first_name="Updated"),
|
||||||
|
)
|
||||||
|
assert updated.first_name == "Updated"
|
||||||
|
assert updated.id == async_test_user.id
|
||||||
|
|
||||||
|
|
||||||
|
class TestSoftDeleteUser:
|
||||||
|
"""Tests for UserService.soft_delete_user method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_soft_delete_user(self, async_test_db, async_test_user):
|
||||||
|
"""Test soft-deleting a user sets deleted_at."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await user_service.soft_delete_user(session, str(async_test_user.id))
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.id == async_test_user.id)
|
||||||
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
assert user is not None
|
||||||
|
assert user.deleted_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestListUsers:
|
||||||
|
"""Tests for UserService.list_users method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_users(self, async_test_db, async_test_user):
|
||||||
|
"""Test listing users with pagination returns correct results."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
users, count = await user_service.list_users(session, skip=0, limit=10)
|
||||||
|
assert isinstance(users, list)
|
||||||
|
assert isinstance(count, int)
|
||||||
|
assert count >= 1
|
||||||
|
assert len(users) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_users_with_search(self, async_test_db, async_test_user):
|
||||||
|
"""Test listing users with email fragment search returns matching users."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
# Search by partial email fragment of the test user
|
||||||
|
email_fragment = async_test_user.email.split("@")[0]
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
users, count = await user_service.list_users(
|
||||||
|
session, skip=0, limit=10, search=email_fragment
|
||||||
|
)
|
||||||
|
assert isinstance(users, list)
|
||||||
|
assert count >= 1
|
||||||
|
emails = [u.email for u in users]
|
||||||
|
assert async_test_user.email in emails
|
||||||
|
|
||||||
|
|
||||||
|
class TestBulkUpdateStatus:
|
||||||
|
"""Tests for UserService.bulk_update_status method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_update_status(self, async_test_db, async_test_user):
|
||||||
|
"""Test bulk activating users returns correct count."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
count = await user_service.bulk_update_status(
|
||||||
|
session,
|
||||||
|
user_ids=[async_test_user.id],
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
assert count >= 1
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.id == async_test_user.id)
|
||||||
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
assert user is not None
|
||||||
|
assert user.is_active is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestBulkSoftDelete:
|
||||||
|
"""Tests for UserService.bulk_soft_delete method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_soft_delete(self, async_test_db, async_test_user):
|
||||||
|
"""Test bulk soft-deleting users returns correct count."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
count = await user_service.bulk_soft_delete(
|
||||||
|
session,
|
||||||
|
user_ids=[async_test_user.id],
|
||||||
|
)
|
||||||
|
assert count >= 1
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.id == async_test_user.id)
|
||||||
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
assert user is not None
|
||||||
|
assert user.deleted_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetStats:
|
||||||
|
"""Tests for UserService.get_stats method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_stats(self, async_test_db, async_test_user):
|
||||||
|
"""Test get_stats returns dict with expected keys and correct counts."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
stats = await user_service.get_stats(session)
|
||||||
|
assert "total_users" in stats
|
||||||
|
assert "active_count" in stats
|
||||||
|
assert "inactive_count" in stats
|
||||||
|
assert "all_users" in stats
|
||||||
|
assert stats["total_users"] >= 1
|
||||||
|
assert stats["active_count"] >= 1
|
||||||
|
assert isinstance(stats["all_users"], list)
|
||||||
|
assert len(stats["all_users"]) >= 1
|
||||||
Reference in New Issue
Block a user