refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError, IntegrityConstraintError, InvalidInputError) replacing raw ValueError - eliminate all direct repository imports and raw SQL from route layer - add UserService, SessionService, OrganizationService to service layer - add get_stats/get_org_distribution service methods replacing admin inline SQL - fix timing side-channel in authenticate_user via dummy bcrypt check - replace SHA-256 client secret fallback with explicit InvalidClientError - replace assert with InvalidGrantError in authorization code exchange - replace N+1 token revocation loops with bulk UPDATE statements - rename oauth account token fields (drop misleading 'encrypted' suffix) - add Alembic migration 0003 for token field column rename - add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
@@ -26,14 +26,17 @@ from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from jose import jwt
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
from app.models.user import User
|
||||
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
|
||||
from app.repositories.oauth_client import oauth_client_repo
|
||||
from app.repositories.oauth_consent import oauth_consent_repo
|
||||
from app.repositories.oauth_provider_token import oauth_provider_token_repo
|
||||
from app.repositories.user import user_repo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -161,15 +164,7 @@ def join_scope(scopes: list[str]) -> str:
|
||||
|
||||
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
|
||||
"""Get OAuth client by client_id."""
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
|
||||
|
||||
|
||||
async def validate_client(
|
||||
@@ -204,21 +199,19 @@ async def validate_client(
|
||||
if not client.client_secret_hash:
|
||||
raise InvalidClientError("Client not configured with secret")
|
||||
|
||||
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
||||
# Supports both bcrypt and legacy SHA-256 hashes for migration
|
||||
# SECURITY: Verify secret using bcrypt
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash = str(client.client_secret_hash)
|
||||
|
||||
if stored_hash.startswith("$2"):
|
||||
# New bcrypt format
|
||||
if not verify_password(client_secret, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
else:
|
||||
# Legacy SHA-256 format
|
||||
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
if not secrets.compare_digest(computed_hash, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
if not stored_hash.startswith("$2"):
|
||||
raise InvalidClientError(
|
||||
"Client secret uses deprecated hash format. "
|
||||
"Please regenerate your client credentials."
|
||||
)
|
||||
|
||||
if not verify_password(client_secret, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
|
||||
return client
|
||||
|
||||
@@ -311,23 +304,20 @@ async def create_authorization_code(
|
||||
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
|
||||
)
|
||||
|
||||
auth_code = OAuthAuthorizationCode(
|
||||
await oauth_authorization_code_repo.create_code(
|
||||
db,
|
||||
code=code,
|
||||
client_id=client.client_id,
|
||||
user_id=user.id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
expires_at=expires_at,
|
||||
used=False,
|
||||
)
|
||||
|
||||
db.add(auth_code)
|
||||
await db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Created authorization code for user {user.id} and client {client.client_id}"
|
||||
)
|
||||
@@ -366,30 +356,14 @@ async def exchange_authorization_code(
|
||||
"""
|
||||
# Atomically mark code as used and fetch it (prevents race condition)
|
||||
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
||||
from sqlalchemy import update
|
||||
|
||||
# First, atomically mark the code as used and get affected count
|
||||
update_stmt = (
|
||||
update(OAuthAuthorizationCode)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAuthorizationCode.code == code,
|
||||
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(used=True)
|
||||
.returning(OAuthAuthorizationCode.id)
|
||||
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
|
||||
db, code=code
|
||||
)
|
||||
result = await db.execute(update_stmt)
|
||||
updated_id = result.scalar_one_or_none()
|
||||
|
||||
if not updated_id:
|
||||
# Either code doesn't exist or was already used
|
||||
# Check if it exists to provide appropriate error
|
||||
check_result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||
)
|
||||
existing_code = check_result.scalar_one_or_none()
|
||||
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
|
||||
|
||||
if existing_code and existing_code.used:
|
||||
# Code reuse is a security incident - revoke all tokens for this grant
|
||||
@@ -404,11 +378,9 @@ async def exchange_authorization_code(
|
||||
raise InvalidGrantError("Invalid authorization code")
|
||||
|
||||
# Now fetch the full auth code record
|
||||
auth_code_result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
|
||||
)
|
||||
auth_code = auth_code_result.scalar_one()
|
||||
await db.commit()
|
||||
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
|
||||
if auth_code is None:
|
||||
raise InvalidGrantError("Authorization code not found after consumption")
|
||||
|
||||
if auth_code.is_expired:
|
||||
raise InvalidGrantError("Authorization code has expired")
|
||||
@@ -452,8 +424,7 @@ async def exchange_authorization_code(
|
||||
raise InvalidGrantError("PKCE required for public clients")
|
||||
|
||||
# Get user
|
||||
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
user = await user_repo.get(db, id=str(auth_code.user_id))
|
||||
if not user or not user.is_active:
|
||||
raise InvalidGrantError("User not found or inactive")
|
||||
|
||||
@@ -543,7 +514,8 @@ async def create_tokens(
|
||||
refresh_token_hash = hash_token(refresh_token)
|
||||
|
||||
# Store refresh token in database
|
||||
refresh_token_record = OAuthProviderRefreshToken(
|
||||
await oauth_provider_token_repo.create_token(
|
||||
db,
|
||||
token_hash=refresh_token_hash,
|
||||
jti=jti,
|
||||
client_id=client.client_id,
|
||||
@@ -553,8 +525,6 @@ async def create_tokens(
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
db.add(refresh_token_record)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"Issued tokens for user {user.id} to client {client.client_id}")
|
||||
|
||||
@@ -599,12 +569,9 @@ async def refresh_tokens(
|
||||
"""
|
||||
# Find refresh token
|
||||
token_hash = hash_token(refresh_token)
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
token_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
)
|
||||
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
|
||||
|
||||
if not token_record:
|
||||
raise InvalidGrantError("Invalid refresh token")
|
||||
@@ -631,8 +598,7 @@ async def refresh_tokens(
|
||||
)
|
||||
|
||||
# Get user
|
||||
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
user = await user_repo.get(db, id=str(token_record.user_id))
|
||||
if not user or not user.is_active:
|
||||
raise InvalidGrantError("User not found or inactive")
|
||||
|
||||
@@ -648,9 +614,7 @@ async def refresh_tokens(
|
||||
final_scope = token_scope
|
||||
|
||||
# Revoke old refresh token (token rotation)
|
||||
token_record.revoked = True # type: ignore[assignment]
|
||||
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||
await db.commit()
|
||||
await oauth_provider_token_repo.revoke(db, token=token_record)
|
||||
|
||||
# Issue new tokens
|
||||
device = str(token_record.device_info) if token_record.device_info else None
|
||||
@@ -697,20 +661,16 @@ async def revoke_token(
|
||||
# Try as refresh token first (more likely)
|
||||
if token_type_hint != "access_token":
|
||||
token_hash = hash_token(token)
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
)
|
||||
refresh_record = result.scalar_one_or_none()
|
||||
|
||||
if refresh_record:
|
||||
# Validate client if provided
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
|
||||
refresh_record.revoked = True # type: ignore[assignment]
|
||||
await db.commit()
|
||||
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
|
||||
return True
|
||||
|
||||
@@ -731,17 +691,13 @@ async def revoke_token(
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
# Find and revoke the associated refresh token
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.jti == jti
|
||||
)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_jti(
|
||||
db, jti=jti
|
||||
)
|
||||
refresh_record = result.scalar_one_or_none()
|
||||
if refresh_record:
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
refresh_record.revoked = True # type: ignore[assignment]
|
||||
await db.commit()
|
||||
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||
logger.info(
|
||||
f"Revoked refresh token via access token JTI {jti[:8]}..."
|
||||
)
|
||||
@@ -770,24 +726,11 @@ async def revoke_tokens_for_user_client(
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.client_id == client_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
count = await oauth_provider_token_repo.revoke_all_for_user_client(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
)
|
||||
tokens = result.scalars().all()
|
||||
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
logger.warning(
|
||||
f"Revoked {count} tokens for user {user_id} and client {client_id}"
|
||||
)
|
||||
@@ -808,23 +751,9 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
tokens = result.scalars().all()
|
||||
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
count += 1
|
||||
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
|
||||
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}")
|
||||
|
||||
return count
|
||||
@@ -878,12 +807,9 @@ async def introspect_token(
|
||||
# Check if associated refresh token is revoked
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.jti == jti
|
||||
)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_jti(
|
||||
db, jti=jti
|
||||
)
|
||||
refresh_record = result.scalar_one_or_none()
|
||||
if refresh_record and refresh_record.revoked:
|
||||
return {"active": False}
|
||||
|
||||
@@ -907,12 +833,9 @@ async def introspect_token(
|
||||
# Try as refresh token
|
||||
if token_type_hint != "access_token":
|
||||
token_hash = hash_token(token)
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
)
|
||||
refresh_record = result.scalar_one_or_none()
|
||||
|
||||
if refresh_record and refresh_record.is_valid:
|
||||
return {
|
||||
@@ -937,17 +860,9 @@ async def get_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
) -> OAuthConsent | None:
|
||||
):
|
||||
"""Get existing consent record for user-client pair."""
|
||||
result = await db.execute(
|
||||
select(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
return await oauth_consent_repo.get_consent(db, user_id=user_id, client_id=client_id)
|
||||
|
||||
|
||||
async def check_consent(
|
||||
@@ -972,31 +887,15 @@ async def grant_consent(
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
) -> OAuthConsent:
|
||||
):
|
||||
"""
|
||||
Grant or update consent for a user-client pair.
|
||||
|
||||
If consent already exists, updates the granted scopes.
|
||||
"""
|
||||
consent = await get_consent(db, user_id, client_id)
|
||||
|
||||
if consent:
|
||||
# Merge scopes
|
||||
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
|
||||
existing = set(parse_scope(granted))
|
||||
new_scopes = existing | set(scopes)
|
||||
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
|
||||
else:
|
||||
consent = OAuthConsent(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
granted_scopes=join_scope(scopes),
|
||||
)
|
||||
db.add(consent)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(consent)
|
||||
return consent
|
||||
return await oauth_consent_repo.grant_consent(
|
||||
db, user_id=user_id, client_id=client_id, scopes=scopes
|
||||
)
|
||||
|
||||
|
||||
async def revoke_consent(
|
||||
@@ -1009,21 +908,13 @@ async def revoke_consent(
|
||||
|
||||
Returns True if consent was found and revoked.
|
||||
"""
|
||||
# Delete consent record
|
||||
result = await db.execute(
|
||||
delete(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Revoke all tokens
|
||||
# Revoke all tokens first
|
||||
await revoke_tokens_for_user_client(db, user_id, client_id)
|
||||
|
||||
await db.commit()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
# Delete consent record
|
||||
return await oauth_consent_repo.revoke_consent(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -1031,6 +922,26 @@ async def revoke_consent(
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
|
||||
"""Create a new OAuth client. Returns (client, secret)."""
|
||||
return await oauth_client_repo.create_client(db, obj_in=client_data)
|
||||
|
||||
|
||||
async def list_clients(db: AsyncSession) -> list:
|
||||
"""List all registered OAuth clients."""
|
||||
return await oauth_client_repo.get_all_clients(db)
|
||||
|
||||
|
||||
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
|
||||
"""Delete an OAuth client by client_id."""
|
||||
await oauth_client_repo.delete_client(db, client_id=client_id)
|
||||
|
||||
|
||||
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
|
||||
"""Get all OAuth consents for a user with client details."""
|
||||
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
|
||||
|
||||
|
||||
async def cleanup_expired_codes(db: AsyncSession) -> int:
|
||||
"""
|
||||
Delete expired authorization codes.
|
||||
@@ -1040,13 +951,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
|
||||
Returns:
|
||||
Number of codes deleted
|
||||
"""
|
||||
result = await db.execute(
|
||||
delete(OAuthAuthorizationCode).where(
|
||||
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
return await oauth_authorization_code_repo.cleanup_expired(db)
|
||||
|
||||
|
||||
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||
@@ -1058,12 +963,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||
Returns:
|
||||
Number of tokens deleted
|
||||
"""
|
||||
# Delete tokens that are both expired AND revoked (or just very old)
|
||||
cutoff = datetime.now(UTC) - timedelta(days=7)
|
||||
result = await db.execute(
|
||||
delete(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.expires_at < cutoff
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)
|
||||
|
||||
Reference in New Issue
Block a user