refactor(backend): enforce route→service→repo layered architecture

- introduce custom repository exception hierarchy (DuplicateEntryError,
  IntegrityConstraintError, InvalidInputError) replacing raw ValueError
- eliminate all direct repository imports and raw SQL from route layer
- add UserService, SessionService, OrganizationService to service layer
- add get_stats/get_org_distribution service methods replacing admin inline SQL
- fix timing side-channel in authenticate_user via dummy bcrypt check
- replace SHA-256 client secret fallback with explicit InvalidClientError
- replace assert with InvalidGrantError in authorization code exchange
- replace N+1 token revocation loops with bulk UPDATE statements
- rename oauth account token fields (drop misleading 'encrypted' suffix)
- add Alembic migration 0003 for token field column rename
- add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
2026-02-27 09:32:57 +01:00
parent 0646c96b19
commit 98b455fdc3
62 changed files with 2933 additions and 1728 deletions

View File

@@ -26,14 +26,17 @@ from typing import Any
from uuid import UUID
from jose import jwt
from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.models.oauth_authorization_code import OAuthAuthorizationCode
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
from app.schemas.oauth import OAuthClientCreate
from app.models.user import User
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
from app.repositories.oauth_client import oauth_client_repo
from app.repositories.oauth_consent import oauth_consent_repo
from app.repositories.oauth_provider_token import oauth_provider_token_repo
from app.repositories.user import user_repo
logger = logging.getLogger(__name__)
@@ -161,15 +164,7 @@ def join_scope(scopes: list[str]) -> str:
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
"""Get OAuth client by client_id."""
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
async def validate_client(
@@ -204,21 +199,19 @@ async def validate_client(
if not client.client_secret_hash:
raise InvalidClientError("Client not configured with secret")
# SECURITY: Verify secret using bcrypt (not SHA-256)
# Supports both bcrypt and legacy SHA-256 hashes for migration
# SECURITY: Verify secret using bcrypt
from app.core.auth import verify_password
stored_hash = str(client.client_secret_hash)
if stored_hash.startswith("$2"):
# New bcrypt format
if not verify_password(client_secret, stored_hash):
raise InvalidClientError("Invalid client secret")
else:
# Legacy SHA-256 format
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
if not secrets.compare_digest(computed_hash, stored_hash):
raise InvalidClientError("Invalid client secret")
if not stored_hash.startswith("$2"):
raise InvalidClientError(
"Client secret uses deprecated hash format. "
"Please regenerate your client credentials."
)
if not verify_password(client_secret, stored_hash):
raise InvalidClientError("Invalid client secret")
return client
@@ -311,23 +304,20 @@ async def create_authorization_code(
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
)
auth_code = OAuthAuthorizationCode(
await oauth_authorization_code_repo.create_code(
db,
code=code,
client_id=client.client_id,
user_id=user.id,
redirect_uri=redirect_uri,
scope=scope,
expires_at=expires_at,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
expires_at=expires_at,
used=False,
)
db.add(auth_code)
await db.commit()
logger.info(
f"Created authorization code for user {user.id} and client {client.client_id}"
)
@@ -366,30 +356,14 @@ async def exchange_authorization_code(
"""
# Atomically mark code as used and fetch it (prevents race condition)
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
from sqlalchemy import update
# First, atomically mark the code as used and get affected count
update_stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
db, code=code
)
result = await db.execute(update_stmt)
updated_id = result.scalar_one_or_none()
if not updated_id:
# Either code doesn't exist or was already used
# Check if it exists to provide appropriate error
check_result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
)
existing_code = check_result.scalar_one_or_none()
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
if existing_code and existing_code.used:
# Code reuse is a security incident - revoke all tokens for this grant
@@ -404,11 +378,9 @@ async def exchange_authorization_code(
raise InvalidGrantError("Invalid authorization code")
# Now fetch the full auth code record
auth_code_result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
)
auth_code = auth_code_result.scalar_one()
await db.commit()
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
if auth_code is None:
raise InvalidGrantError("Authorization code not found after consumption")
if auth_code.is_expired:
raise InvalidGrantError("Authorization code has expired")
@@ -452,8 +424,7 @@ async def exchange_authorization_code(
raise InvalidGrantError("PKCE required for public clients")
# Get user
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
user = user_result.scalar_one_or_none()
user = await user_repo.get(db, id=str(auth_code.user_id))
if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive")
@@ -543,7 +514,8 @@ async def create_tokens(
refresh_token_hash = hash_token(refresh_token)
# Store refresh token in database
refresh_token_record = OAuthProviderRefreshToken(
await oauth_provider_token_repo.create_token(
db,
token_hash=refresh_token_hash,
jti=jti,
client_id=client.client_id,
@@ -553,8 +525,6 @@ async def create_tokens(
device_info=device_info,
ip_address=ip_address,
)
db.add(refresh_token_record)
await db.commit()
logger.info(f"Issued tokens for user {user.id} to client {client.client_id}")
@@ -599,12 +569,9 @@ async def refresh_tokens(
"""
# Find refresh token
token_hash = hash_token(refresh_token)
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
token_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
if not token_record:
raise InvalidGrantError("Invalid refresh token")
@@ -631,8 +598,7 @@ async def refresh_tokens(
)
# Get user
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
user = user_result.scalar_one_or_none()
user = await user_repo.get(db, id=str(token_record.user_id))
if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive")
@@ -648,9 +614,7 @@ async def refresh_tokens(
final_scope = token_scope
# Revoke old refresh token (token rotation)
token_record.revoked = True # type: ignore[assignment]
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
await db.commit()
await oauth_provider_token_repo.revoke(db, token=token_record)
# Issue new tokens
device = str(token_record.device_info) if token_record.device_info else None
@@ -697,20 +661,16 @@ async def revoke_token(
# Try as refresh token first (more likely)
if token_type_hint != "access_token":
token_hash = hash_token(token)
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
refresh_record = result.scalar_one_or_none()
if refresh_record:
# Validate client if provided
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
refresh_record.revoked = True # type: ignore[assignment]
await db.commit()
await oauth_provider_token_repo.revoke(db, token=refresh_record)
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
return True
@@ -731,17 +691,13 @@ async def revoke_token(
jti = payload.get("jti")
if jti:
# Find and revoke the associated refresh token
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
refresh_record = await oauth_provider_token_repo.get_by_jti(
db, jti=jti
)
refresh_record = result.scalar_one_or_none()
if refresh_record:
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
refresh_record.revoked = True # type: ignore[assignment]
await db.commit()
await oauth_provider_token_repo.revoke(db, token=refresh_record)
logger.info(
f"Revoked refresh token via access token JTI {jti[:8]}..."
)
@@ -770,24 +726,11 @@ async def revoke_tokens_for_user_client(
Returns:
Number of tokens revoked
"""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.client_id == client_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
count = await oauth_provider_token_repo.revoke_all_for_user_client(
db, user_id=user_id, client_id=client_id
)
tokens = result.scalars().all()
count = 0
for token in tokens:
token.revoked = True # type: ignore[assignment]
count += 1
if count > 0:
await db.commit()
logger.warning(
f"Revoked {count} tokens for user {user_id} and client {client_id}"
)
@@ -808,23 +751,9 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
Returns:
Number of tokens revoked
"""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
)
tokens = result.scalars().all()
count = 0
for token in tokens:
token.revoked = True # type: ignore[assignment]
count += 1
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
if count > 0:
await db.commit()
logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}")
return count
@@ -878,12 +807,9 @@ async def introspect_token(
# Check if associated refresh token is revoked
jti = payload.get("jti")
if jti:
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
refresh_record = await oauth_provider_token_repo.get_by_jti(
db, jti=jti
)
refresh_record = result.scalar_one_or_none()
if refresh_record and refresh_record.revoked:
return {"active": False}
@@ -907,12 +833,9 @@ async def introspect_token(
# Try as refresh token
if token_type_hint != "access_token":
token_hash = hash_token(token)
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
refresh_record = result.scalar_one_or_none()
if refresh_record and refresh_record.is_valid:
return {
@@ -937,17 +860,9 @@ async def get_consent(
db: AsyncSession,
user_id: UUID,
client_id: str,
) -> OAuthConsent | None:
):
"""Get existing consent record for user-client pair."""
result = await db.execute(
select(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
return result.scalar_one_or_none()
return await oauth_consent_repo.get_consent(db, user_id=user_id, client_id=client_id)
async def check_consent(
@@ -972,31 +887,15 @@ async def grant_consent(
user_id: UUID,
client_id: str,
scopes: list[str],
) -> OAuthConsent:
):
"""
Grant or update consent for a user-client pair.
If consent already exists, updates the granted scopes.
"""
consent = await get_consent(db, user_id, client_id)
if consent:
# Merge scopes
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
existing = set(parse_scope(granted))
new_scopes = existing | set(scopes)
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
else:
consent = OAuthConsent(
user_id=user_id,
client_id=client_id,
granted_scopes=join_scope(scopes),
)
db.add(consent)
await db.commit()
await db.refresh(consent)
return consent
return await oauth_consent_repo.grant_consent(
db, user_id=user_id, client_id=client_id, scopes=scopes
)
async def revoke_consent(
@@ -1009,21 +908,13 @@ async def revoke_consent(
Returns True if consent was found and revoked.
"""
# Delete consent record
result = await db.execute(
delete(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
# Revoke all tokens
# Revoke all tokens first
await revoke_tokens_for_user_client(db, user_id, client_id)
await db.commit()
return result.rowcount > 0 # type: ignore[attr-defined]
# Delete consent record
return await oauth_consent_repo.revoke_consent(
db, user_id=user_id, client_id=client_id
)
# ============================================================================
@@ -1031,6 +922,26 @@ async def revoke_consent(
# ============================================================================
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
"""Create a new OAuth client. Returns (client, secret)."""
return await oauth_client_repo.create_client(db, obj_in=client_data)
async def list_clients(db: AsyncSession) -> list:
"""List all registered OAuth clients."""
return await oauth_client_repo.get_all_clients(db)
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
"""Delete an OAuth client by client_id."""
await oauth_client_repo.delete_client(db, client_id=client_id)
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
"""Get all OAuth consents for a user with client details."""
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
async def cleanup_expired_codes(db: AsyncSession) -> int:
"""
Delete expired authorization codes.
@@ -1040,13 +951,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
Returns:
Number of codes deleted
"""
result = await db.execute(
delete(OAuthAuthorizationCode).where(
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
return await oauth_authorization_code_repo.cleanup_expired(db)
async def cleanup_expired_tokens(db: AsyncSession) -> int:
@@ -1058,12 +963,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
Returns:
Number of tokens deleted
"""
# Delete tokens that are both expired AND revoked (or just very old)
cutoff = datetime.now(UTC) - timedelta(days=7)
result = await db.execute(
delete(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.expires_at < cutoff
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)