forked from cardosofelipe/fast-next-template
Enhance OAuth security, PKCE, and state validation
- Enforced stricter PKCE requirements by rejecting insecure 'plain' method for public clients. - Transitioned client secret hashing to bcrypt for improved security and migration compatibility. - Added constant-time comparison for state parameter validation to prevent timing attacks. - Improved error handling and logging for OAuth workflows, including malformed headers and invalid scopes. - Upgraded Google OIDC token validation to verify both signature and nonce. - Refactored OAuth service methods and schemas for better readability, consistency, and compliance with RFC specifications.
This commit is contained in:
@@ -126,16 +126,22 @@ def hash_token(token: str) -> str:
|
||||
|
||||
|
||||
def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
|
||||
"""Verify PKCE code_verifier against stored code_challenge."""
|
||||
if method == "S256":
|
||||
# SHA-256 hash, then base64url encode
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return secrets.compare_digest(computed, code_challenge)
|
||||
elif method == "plain":
|
||||
# Direct comparison (not recommended, but supported)
|
||||
return secrets.compare_digest(code_verifier, code_challenge)
|
||||
return False
|
||||
"""
|
||||
Verify PKCE code_verifier against stored code_challenge.
|
||||
|
||||
SECURITY: Only S256 method is supported. The 'plain' method provides
|
||||
no security benefit and is explicitly rejected per RFC 7636 Section 4.3.
|
||||
"""
|
||||
if method != "S256":
|
||||
# SECURITY: Reject any method other than S256
|
||||
# 'plain' method provides no security against code interception attacks
|
||||
logger.warning(f"PKCE verification rejected for unsupported method: {method}")
|
||||
return False
|
||||
|
||||
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return secrets.compare_digest(computed, code_challenge)
|
||||
|
||||
|
||||
def parse_scope(scope: str) -> list[str]:
|
||||
@@ -198,10 +204,21 @@ async def validate_client(
|
||||
if not client.client_secret_hash:
|
||||
raise InvalidClientError("Client not configured with secret")
|
||||
|
||||
# Verify secret using SHA256 hash (consistent with CRUD)
|
||||
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
if not secrets.compare_digest(computed_hash, client.client_secret_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
||||
# Supports both bcrypt and legacy SHA-256 hashes for migration
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash = str(client.client_secret_hash)
|
||||
|
||||
if stored_hash.startswith("$2"):
|
||||
# New bcrypt format
|
||||
if not verify_password(client_secret, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
else:
|
||||
# Legacy SHA-256 format
|
||||
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
if not secrets.compare_digest(computed_hash, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
|
||||
return client
|
||||
|
||||
@@ -246,9 +263,7 @@ def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[st
|
||||
# Warn if some scopes were filtered out
|
||||
invalid = requested - allowed
|
||||
if invalid:
|
||||
logger.warning(
|
||||
f"Client {client.client_id} requested invalid scopes: {invalid}"
|
||||
)
|
||||
logger.warning(f"Client {client.client_id} requested invalid scopes: {invalid}")
|
||||
|
||||
return list(valid)
|
||||
|
||||
@@ -382,17 +397,17 @@ async def exchange_authorization_code(
|
||||
f"Authorization code reuse detected for client {existing_code.client_id}"
|
||||
)
|
||||
await revoke_tokens_for_user_client(
|
||||
db, existing_code.user_id, existing_code.client_id
|
||||
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
|
||||
)
|
||||
raise InvalidGrantError("Authorization code has already been used")
|
||||
else:
|
||||
raise InvalidGrantError("Invalid authorization code")
|
||||
|
||||
# Now fetch the full auth code record
|
||||
result = await db.execute(
|
||||
auth_code_result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
|
||||
)
|
||||
auth_code = result.scalar_one()
|
||||
auth_code = auth_code_result.scalar_one()
|
||||
await db.commit()
|
||||
|
||||
if auth_code.is_expired:
|
||||
@@ -413,10 +428,14 @@ async def exchange_authorization_code(
|
||||
if client.client_type == "confidential":
|
||||
if not client_secret:
|
||||
raise InvalidClientError("Client secret required for confidential clients")
|
||||
client = await validate_client(db, client_id, client_secret, require_secret=True)
|
||||
client = await validate_client(
|
||||
db, client_id, client_secret, require_secret=True
|
||||
)
|
||||
elif client_secret:
|
||||
# Public client provided secret - validate it if given
|
||||
client = await validate_client(db, client_id, client_secret, require_secret=True)
|
||||
client = await validate_client(
|
||||
db, client_id, client_secret, require_secret=True
|
||||
)
|
||||
|
||||
# Verify PKCE
|
||||
if auth_code.code_challenge:
|
||||
@@ -424,8 +443,8 @@ async def exchange_authorization_code(
|
||||
raise InvalidGrantError("code_verifier required")
|
||||
if not verify_pkce(
|
||||
code_verifier,
|
||||
auth_code.code_challenge,
|
||||
auth_code.code_challenge_method or "S256",
|
||||
str(auth_code.code_challenge),
|
||||
str(auth_code.code_challenge_method or "S256"),
|
||||
):
|
||||
raise InvalidGrantError("Invalid code_verifier")
|
||||
elif client.client_type == "public":
|
||||
@@ -443,8 +462,8 @@ async def exchange_authorization_code(
|
||||
db=db,
|
||||
client=client,
|
||||
user=user,
|
||||
scope=auth_code.scope,
|
||||
nonce=auth_code.nonce,
|
||||
scope=str(auth_code.scope),
|
||||
nonce=str(auth_code.nonce) if auth_code.nonce else None,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
@@ -487,16 +506,20 @@ async def create_tokens(
|
||||
access_expires = now + timedelta(seconds=access_token_lifetime)
|
||||
|
||||
# Refresh token expiry
|
||||
refresh_token_lifetime = int(client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400))
|
||||
refresh_token_lifetime = int(
|
||||
client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400)
|
||||
)
|
||||
refresh_expires = now + timedelta(seconds=refresh_token_lifetime)
|
||||
|
||||
# Create JWT access token
|
||||
# SECURITY: Include all standard JWT claims per RFC 7519
|
||||
access_token_payload = {
|
||||
"iss": settings.OAUTH_ISSUER,
|
||||
"sub": str(user.id),
|
||||
"aud": client.client_id,
|
||||
"exp": int(access_expires.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"nbf": int(now.timestamp()), # Not Before - token is valid immediately
|
||||
"jti": jti,
|
||||
"scope": scope,
|
||||
"client_id": client.client_id,
|
||||
@@ -581,7 +604,7 @@ async def refresh_tokens(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
|
||||
|
||||
if not token_record:
|
||||
raise InvalidGrantError("Invalid refresh token")
|
||||
@@ -608,36 +631,37 @@ async def refresh_tokens(
|
||||
)
|
||||
|
||||
# Get user
|
||||
user_result = await db.execute(
|
||||
select(User).where(User.id == token_record.user_id)
|
||||
)
|
||||
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if not user or not user.is_active:
|
||||
raise InvalidGrantError("User not found or inactive")
|
||||
|
||||
# Validate scope (can only reduce, not expand)
|
||||
original_scopes = set(parse_scope(token_record.scope))
|
||||
token_scope = str(token_record.scope) if token_record.scope else ""
|
||||
original_scopes = set(parse_scope(token_scope))
|
||||
if scope:
|
||||
requested_scopes = set(parse_scope(scope))
|
||||
if not requested_scopes.issubset(original_scopes):
|
||||
raise InvalidScopeError("Cannot expand scope on refresh")
|
||||
final_scope = join_scope(list(requested_scopes))
|
||||
else:
|
||||
final_scope = token_record.scope
|
||||
final_scope = token_scope
|
||||
|
||||
# Revoke old refresh token (token rotation)
|
||||
token_record.revoked = True
|
||||
token_record.last_used_at = datetime.now(UTC)
|
||||
token_record.revoked = True # type: ignore[assignment]
|
||||
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||
await db.commit()
|
||||
|
||||
# Issue new tokens
|
||||
device = str(token_record.device_info) if token_record.device_info else None
|
||||
ip_addr = str(token_record.ip_address) if token_record.ip_address else None
|
||||
return await create_tokens(
|
||||
db=db,
|
||||
client=client,
|
||||
user=user,
|
||||
scope=final_scope,
|
||||
device_info=device_info or token_record.device_info,
|
||||
ip_address=ip_address or token_record.ip_address,
|
||||
device_info=device_info or device,
|
||||
ip_address=ip_address or ip_addr,
|
||||
)
|
||||
|
||||
|
||||
@@ -685,7 +709,7 @@ async def revoke_token(
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
|
||||
refresh_record.revoked = True
|
||||
refresh_record.revoked = True # type: ignore[assignment]
|
||||
await db.commit()
|
||||
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
|
||||
return True
|
||||
@@ -699,7 +723,10 @@ async def revoke_token(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM],
|
||||
options={"verify_exp": False, "verify_aud": False}, # Allow expired tokens
|
||||
options={
|
||||
"verify_exp": False,
|
||||
"verify_aud": False,
|
||||
}, # Allow expired tokens
|
||||
)
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
@@ -713,7 +740,7 @@ async def revoke_token(
|
||||
if refresh_record:
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
refresh_record.revoked = True
|
||||
refresh_record.revoked = True # type: ignore[assignment]
|
||||
await db.commit()
|
||||
logger.info(
|
||||
f"Revoked refresh token via access token JTI {jti[:8]}..."
|
||||
@@ -756,7 +783,7 @@ async def revoke_tokens_for_user_client(
|
||||
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoked = True
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
@@ -793,7 +820,7 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
|
||||
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoked = True
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
@@ -843,7 +870,9 @@ async def introspect_token(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM],
|
||||
options={"verify_aud": False}, # Don't require audience match for introspection
|
||||
options={
|
||||
"verify_aud": False
|
||||
}, # Don't require audience match for introspection
|
||||
)
|
||||
|
||||
# Check if associated refresh token is revoked
|
||||
@@ -953,9 +982,10 @@ async def grant_consent(
|
||||
|
||||
if consent:
|
||||
# Merge scopes
|
||||
existing = set(parse_scope(consent.granted_scopes))
|
||||
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
|
||||
existing = set(parse_scope(granted))
|
||||
new_scopes = existing | set(scopes)
|
||||
consent.granted_scopes = join_scope(list(new_scopes))
|
||||
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
|
||||
else:
|
||||
consent = OAuthConsent(
|
||||
user_id=user_id,
|
||||
@@ -993,7 +1023,7 @@ async def revoke_consent(
|
||||
await revoke_tokens_for_user_client(db, user_id, client_id)
|
||||
|
||||
await db.commit()
|
||||
return result.rowcount > 0
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -1016,7 +1046,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||
@@ -1036,4 +1066,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
Reference in New Issue
Block a user