Enhance OAuth security and state validation
- Implemented stricter OAuth security measures, including CSRF protection via state parameter validation and redirect_uri checks. - Updated OAuth models to support timezone-aware datetime comparisons, replacing deprecated `utcnow`. - Enhanced logging for malformed Basic auth headers during token, introspect, and revoke requests. - Added allowlist validation for OAuth provider domains to prevent open redirect attacks. - Improved nonce validation for OpenID Connect tokens, ensuring token integrity during Google provider flows. - Updated E2E and unit tests to cover new security features and expanded OAuth state handling scenarios.
This commit is contained in:
@@ -349,22 +349,51 @@ async def exchange_authorization_code(
|
||||
InvalidGrantError: If code is invalid, expired, or already used
|
||||
InvalidClientError: If client validation fails
|
||||
"""
|
||||
# Get and validate authorization code
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||
)
|
||||
auth_code = result.scalar_one_or_none()
|
||||
# Atomically mark code as used and fetch it (prevents race condition)
|
||||
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
||||
from sqlalchemy import update
|
||||
|
||||
if not auth_code:
|
||||
raise InvalidGrantError("Invalid authorization code")
|
||||
|
||||
if auth_code.used:
|
||||
# Code reuse is a security incident - revoke all tokens for this grant
|
||||
logger.warning(
|
||||
f"Authorization code reuse detected for client {auth_code.client_id}"
|
||||
# First, atomically mark the code as used and get affected count
|
||||
update_stmt = (
|
||||
update(OAuthAuthorizationCode)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAuthorizationCode.code == code,
|
||||
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
await revoke_tokens_for_user_client(db, auth_code.user_id, auth_code.client_id)
|
||||
raise InvalidGrantError("Authorization code has already been used")
|
||||
.values(used=True)
|
||||
.returning(OAuthAuthorizationCode.id)
|
||||
)
|
||||
result = await db.execute(update_stmt)
|
||||
updated_id = result.scalar_one_or_none()
|
||||
|
||||
if not updated_id:
|
||||
# Either code doesn't exist or was already used
|
||||
# Check if it exists to provide appropriate error
|
||||
check_result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||
)
|
||||
existing_code = check_result.scalar_one_or_none()
|
||||
|
||||
if existing_code and existing_code.used:
|
||||
# Code reuse is a security incident - revoke all tokens for this grant
|
||||
logger.warning(
|
||||
f"Authorization code reuse detected for client {existing_code.client_id}"
|
||||
)
|
||||
await revoke_tokens_for_user_client(
|
||||
db, existing_code.user_id, existing_code.client_id
|
||||
)
|
||||
raise InvalidGrantError("Authorization code has already been used")
|
||||
else:
|
||||
raise InvalidGrantError("Invalid authorization code")
|
||||
|
||||
# Now fetch the full auth code record
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
|
||||
)
|
||||
auth_code = result.scalar_one()
|
||||
await db.commit()
|
||||
|
||||
if auth_code.is_expired:
|
||||
raise InvalidGrantError("Authorization code has expired")
|
||||
@@ -375,13 +404,19 @@ async def exchange_authorization_code(
|
||||
if auth_code.redirect_uri != redirect_uri:
|
||||
raise InvalidGrantError("redirect_uri mismatch")
|
||||
|
||||
# Validate client
|
||||
client = await validate_client(
|
||||
db,
|
||||
client_id,
|
||||
client_secret,
|
||||
require_secret=(client_secret is not None),
|
||||
)
|
||||
# Validate client - ALWAYS require secret for confidential clients
|
||||
client = await get_client(db, client_id)
|
||||
if not client:
|
||||
raise InvalidClientError("Unknown client_id")
|
||||
|
||||
# Confidential clients MUST authenticate (RFC 6749 Section 3.2.1)
|
||||
if client.client_type == "confidential":
|
||||
if not client_secret:
|
||||
raise InvalidClientError("Client secret required for confidential clients")
|
||||
client = await validate_client(db, client_id, client_secret, require_secret=True)
|
||||
elif client_secret:
|
||||
# Public client provided secret - validate it if given
|
||||
client = await validate_client(db, client_id, client_secret, require_secret=True)
|
||||
|
||||
# Verify PKCE
|
||||
if auth_code.code_challenge:
|
||||
@@ -397,10 +432,6 @@ async def exchange_authorization_code(
|
||||
# Public clients without PKCE - this shouldn't happen if we validated on authorize
|
||||
raise InvalidGrantError("PKCE required for public clients")
|
||||
|
||||
# Mark code as used (single-use)
|
||||
auth_code.used = True
|
||||
await db.commit()
|
||||
|
||||
# Get user
|
||||
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
|
||||
@@ -246,6 +246,15 @@ class OAuthService:
|
||||
if not state_record:
|
||||
raise AuthenticationError("Invalid or expired OAuth state")
|
||||
|
||||
# SECURITY: Validate redirect_uri matches the one from authorization request
|
||||
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
|
||||
if state_record.redirect_uri != redirect_uri:
|
||||
logger.warning(
|
||||
f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, "
|
||||
f"got {redirect_uri}"
|
||||
)
|
||||
raise AuthenticationError("Redirect URI mismatch")
|
||||
|
||||
# Extract provider from state record (str for type safety)
|
||||
provider: str = str(state_record.provider)
|
||||
|
||||
@@ -272,6 +281,38 @@ class OAuthService:
|
||||
config["token_url"],
|
||||
**token_params,
|
||||
)
|
||||
|
||||
# SECURITY: Validate nonce in ID token for OpenID Connect (Google)
|
||||
# This prevents token replay attacks (OpenID Connect Core 3.1.3.7)
|
||||
if provider == "google" and state_record.nonce:
|
||||
id_token = token.get("id_token")
|
||||
if id_token:
|
||||
import base64
|
||||
import json
|
||||
|
||||
# Decode ID token payload (JWT format: header.payload.signature)
|
||||
try:
|
||||
payload_b64 = id_token.split(".")[1]
|
||||
# Add padding if needed
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
payload_json = base64.urlsafe_b64decode(payload_b64)
|
||||
payload = json.loads(payload_json)
|
||||
|
||||
token_nonce = payload.get("nonce")
|
||||
if token_nonce != state_record.nonce:
|
||||
logger.warning(
|
||||
f"OAuth nonce mismatch: expected {state_record.nonce}, "
|
||||
f"got {token_nonce}"
|
||||
)
|
||||
raise AuthenticationError("Invalid ID token nonce")
|
||||
except (IndexError, ValueError, json.JSONDecodeError) as e:
|
||||
logger.warning(f"Failed to decode ID token for nonce validation: {e}")
|
||||
# Continue without nonce validation if ID token is malformed
|
||||
# The token will still be validated when getting user info
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth token exchange failed: {e!s}")
|
||||
raise AuthenticationError("Failed to exchange authorization code")
|
||||
@@ -294,8 +335,9 @@ class OAuthService:
|
||||
# Process user info and create/link account
|
||||
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
|
||||
# Email can be None if user didn't grant email permission
|
||||
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
|
||||
email_raw = user_info.get("email")
|
||||
provider_email: str | None = str(email_raw) if email_raw else None
|
||||
provider_email: str | None = str(email_raw).lower().strip() if email_raw else None
|
||||
|
||||
if not provider_user_id:
|
||||
raise AuthenticationError("Provider did not return user ID")
|
||||
|
||||
Reference in New Issue
Block a user