forked from cardosofelipe/pragma-stack
refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError, IntegrityConstraintError, InvalidInputError) replacing raw ValueError - eliminate all direct repository imports and raw SQL from route layer - add UserService, SessionService, OrganizationService to service layer - add get_stats/get_org_distribution service methods replacing admin inline SQL - fix timing side-channel in authenticate_user via dummy bcrypt check - replace SHA-256 client secret fallback with explicit InvalidClientError - replace assert with InvalidGrantError in authorization code exchange - replace N+1 token revocation loops with bulk UPDATE statements - rename oauth account token fields (drop misleading 'encrypted' suffix) - add Alembic migration 0003 for token field column rename - add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
@@ -19,14 +19,15 @@ from typing import TypedDict, cast
|
||||
from uuid import UUID
|
||||
|
||||
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import create_access_token, create_refresh_token
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.crud import oauth_account, oauth_state
|
||||
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||
from app.models.user import User
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.oauth import (
|
||||
OAuthAccountCreate,
|
||||
OAuthCallbackResponse,
|
||||
@@ -343,7 +344,7 @@ class OAuthService:
|
||||
await oauth_account.update_tokens(
|
||||
db,
|
||||
account=existing_oauth,
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600)),
|
||||
)
|
||||
|
||||
@@ -351,10 +352,7 @@ class OAuthService:
|
||||
|
||||
elif state_record.user_id:
|
||||
# Account linking flow (user is already logged in)
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == state_record.user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user = await user_repo.get(db, id=str(state_record.user_id))
|
||||
|
||||
if not user:
|
||||
raise AuthenticationError("User not found for account linking")
|
||||
@@ -375,7 +373,7 @@ class OAuthService:
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
@@ -389,10 +387,7 @@ class OAuthService:
|
||||
user = None
|
||||
|
||||
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == provider_email)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user = await user_repo.get_by_email(db, email=provider_email)
|
||||
|
||||
if user:
|
||||
# Auto-link to existing user
|
||||
@@ -416,8 +411,8 @@ class OAuthService:
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token_encrypted=token.get("access_token"),
|
||||
refresh_token_encrypted=token.get("refresh_token"),
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
@@ -644,14 +639,13 @@ class OAuthService:
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=email,
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return user
|
||||
@@ -701,6 +695,20 @@ class OAuthService:
|
||||
logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
|
||||
"""Get all OAuth accounts linked to a user."""
|
||||
return await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_account_by_provider(
|
||||
db: AsyncSession, *, user_id: UUID, provider: str
|
||||
):
|
||||
"""Get a specific OAuth account for a user and provider."""
|
||||
return await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_states(db: AsyncSession) -> int:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user