Enhance OAuth security, PKCE, and state validation

- Enforced stricter PKCE requirements by rejecting insecure 'plain' method for public clients.
- Transitioned client secret hashing to bcrypt for improved security and migration compatibility.
- Added constant-time comparison for state parameter validation to prevent timing attacks.
- Improved error handling and logging for OAuth workflows, including malformed headers and invalid scopes.
- Upgraded Google OIDC token validation to verify both signature and nonce.
- Refactored OAuth service methods and schemas for better readability, consistency, and compliance with RFC specifications.
This commit is contained in:
Felipe Cardoso
2025-11-26 00:14:26 +01:00
parent 0ea428b718
commit dc875c5c95
6 changed files with 284 additions and 159 deletions

View File

@@ -126,16 +126,22 @@ def hash_token(token: str) -> str:
def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
"""Verify PKCE code_verifier against stored code_challenge."""
if method == "S256":
# SHA-256 hash, then base64url encode
digest = hashlib.sha256(code_verifier.encode()).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
return secrets.compare_digest(computed, code_challenge)
elif method == "plain":
# Direct comparison (not recommended, but supported)
return secrets.compare_digest(code_verifier, code_challenge)
return False
"""
Verify PKCE code_verifier against stored code_challenge.
SECURITY: Only S256 method is supported. The 'plain' method provides
no security benefit and is explicitly rejected per RFC 7636 Section 4.3.
"""
if method != "S256":
# SECURITY: Reject any method other than S256
# 'plain' method provides no security against code interception attacks
logger.warning(f"PKCE verification rejected for unsupported method: {method}")
return False
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
digest = hashlib.sha256(code_verifier.encode()).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
return secrets.compare_digest(computed, code_challenge)
def parse_scope(scope: str) -> list[str]:
@@ -198,10 +204,21 @@ async def validate_client(
if not client.client_secret_hash:
raise InvalidClientError("Client not configured with secret")
# Verify secret using SHA256 hash (consistent with CRUD)
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
if not secrets.compare_digest(computed_hash, client.client_secret_hash):
raise InvalidClientError("Invalid client secret")
# SECURITY: Verify secret using bcrypt (not SHA-256)
# Supports both bcrypt and legacy SHA-256 hashes for migration
from app.core.auth import verify_password
stored_hash = str(client.client_secret_hash)
if stored_hash.startswith("$2"):
# New bcrypt format
if not verify_password(client_secret, stored_hash):
raise InvalidClientError("Invalid client secret")
else:
# Legacy SHA-256 format
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
if not secrets.compare_digest(computed_hash, stored_hash):
raise InvalidClientError("Invalid client secret")
return client
@@ -246,9 +263,7 @@ def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[st
# Warn if some scopes were filtered out
invalid = requested - allowed
if invalid:
logger.warning(
f"Client {client.client_id} requested invalid scopes: {invalid}"
)
logger.warning(f"Client {client.client_id} requested invalid scopes: {invalid}")
return list(valid)
@@ -382,17 +397,17 @@ async def exchange_authorization_code(
f"Authorization code reuse detected for client {existing_code.client_id}"
)
await revoke_tokens_for_user_client(
db, existing_code.user_id, existing_code.client_id
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
)
raise InvalidGrantError("Authorization code has already been used")
else:
raise InvalidGrantError("Invalid authorization code")
# Now fetch the full auth code record
result = await db.execute(
auth_code_result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
)
auth_code = result.scalar_one()
auth_code = auth_code_result.scalar_one()
await db.commit()
if auth_code.is_expired:
@@ -413,10 +428,14 @@ async def exchange_authorization_code(
if client.client_type == "confidential":
if not client_secret:
raise InvalidClientError("Client secret required for confidential clients")
client = await validate_client(db, client_id, client_secret, require_secret=True)
client = await validate_client(
db, client_id, client_secret, require_secret=True
)
elif client_secret:
# Public client provided secret - validate it if given
client = await validate_client(db, client_id, client_secret, require_secret=True)
client = await validate_client(
db, client_id, client_secret, require_secret=True
)
# Verify PKCE
if auth_code.code_challenge:
@@ -424,8 +443,8 @@ async def exchange_authorization_code(
raise InvalidGrantError("code_verifier required")
if not verify_pkce(
code_verifier,
auth_code.code_challenge,
auth_code.code_challenge_method or "S256",
str(auth_code.code_challenge),
str(auth_code.code_challenge_method or "S256"),
):
raise InvalidGrantError("Invalid code_verifier")
elif client.client_type == "public":
@@ -443,8 +462,8 @@ async def exchange_authorization_code(
db=db,
client=client,
user=user,
scope=auth_code.scope,
nonce=auth_code.nonce,
scope=str(auth_code.scope),
nonce=str(auth_code.nonce) if auth_code.nonce else None,
device_info=device_info,
ip_address=ip_address,
)
@@ -487,16 +506,20 @@ async def create_tokens(
access_expires = now + timedelta(seconds=access_token_lifetime)
# Refresh token expiry
refresh_token_lifetime = int(client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400))
refresh_token_lifetime = int(
client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400)
)
refresh_expires = now + timedelta(seconds=refresh_token_lifetime)
# Create JWT access token
# SECURITY: Include all standard JWT claims per RFC 7519
access_token_payload = {
"iss": settings.OAUTH_ISSUER,
"sub": str(user.id),
"aud": client.client_id,
"exp": int(access_expires.timestamp()),
"iat": int(now.timestamp()),
"nbf": int(now.timestamp()), # Not Before - token is valid immediately
"jti": jti,
"scope": scope,
"client_id": client.client_id,
@@ -581,7 +604,7 @@ async def refresh_tokens(
OAuthProviderRefreshToken.token_hash == token_hash
)
)
token_record = result.scalar_one_or_none()
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
if not token_record:
raise InvalidGrantError("Invalid refresh token")
@@ -608,36 +631,37 @@ async def refresh_tokens(
)
# Get user
user_result = await db.execute(
select(User).where(User.id == token_record.user_id)
)
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
user = user_result.scalar_one_or_none()
if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive")
# Validate scope (can only reduce, not expand)
original_scopes = set(parse_scope(token_record.scope))
token_scope = str(token_record.scope) if token_record.scope else ""
original_scopes = set(parse_scope(token_scope))
if scope:
requested_scopes = set(parse_scope(scope))
if not requested_scopes.issubset(original_scopes):
raise InvalidScopeError("Cannot expand scope on refresh")
final_scope = join_scope(list(requested_scopes))
else:
final_scope = token_record.scope
final_scope = token_scope
# Revoke old refresh token (token rotation)
token_record.revoked = True
token_record.last_used_at = datetime.now(UTC)
token_record.revoked = True # type: ignore[assignment]
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
await db.commit()
# Issue new tokens
device = str(token_record.device_info) if token_record.device_info else None
ip_addr = str(token_record.ip_address) if token_record.ip_address else None
return await create_tokens(
db=db,
client=client,
user=user,
scope=final_scope,
device_info=device_info or token_record.device_info,
ip_address=ip_address or token_record.ip_address,
device_info=device_info or device,
ip_address=ip_address or ip_addr,
)
@@ -685,7 +709,7 @@ async def revoke_token(
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
refresh_record.revoked = True
refresh_record.revoked = True # type: ignore[assignment]
await db.commit()
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
return True
@@ -699,7 +723,10 @@ async def revoke_token(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM],
options={"verify_exp": False, "verify_aud": False}, # Allow expired tokens
options={
"verify_exp": False,
"verify_aud": False,
}, # Allow expired tokens
)
jti = payload.get("jti")
if jti:
@@ -713,7 +740,7 @@ async def revoke_token(
if refresh_record:
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
refresh_record.revoked = True
refresh_record.revoked = True # type: ignore[assignment]
await db.commit()
logger.info(
f"Revoked refresh token via access token JTI {jti[:8]}..."
@@ -756,7 +783,7 @@ async def revoke_tokens_for_user_client(
count = 0
for token in tokens:
token.revoked = True
token.revoked = True # type: ignore[assignment]
count += 1
if count > 0:
@@ -793,7 +820,7 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
count = 0
for token in tokens:
token.revoked = True
token.revoked = True # type: ignore[assignment]
count += 1
if count > 0:
@@ -843,7 +870,9 @@ async def introspect_token(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM],
options={"verify_aud": False}, # Don't require audience match for introspection
options={
"verify_aud": False
}, # Don't require audience match for introspection
)
# Check if associated refresh token is revoked
@@ -953,9 +982,10 @@ async def grant_consent(
if consent:
# Merge scopes
existing = set(parse_scope(consent.granted_scopes))
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
existing = set(parse_scope(granted))
new_scopes = existing | set(scopes)
consent.granted_scopes = join_scope(list(new_scopes))
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
else:
consent = OAuthConsent(
user_id=user_id,
@@ -993,7 +1023,7 @@ async def revoke_consent(
await revoke_tokens_for_user_client(db, user_id, client_id)
await db.commit()
return result.rowcount > 0
return result.rowcount > 0 # type: ignore[attr-defined]
# ============================================================================
@@ -1016,7 +1046,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
)
)
await db.commit()
return result.rowcount
return result.rowcount # type: ignore[attr-defined]
async def cleanup_expired_tokens(db: AsyncSession) -> int:
@@ -1036,4 +1066,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
)
)
await db.commit()
return result.rowcount
return result.rowcount # type: ignore[attr-defined]

View File

@@ -282,35 +282,16 @@ class OAuthService:
**token_params,
)
# SECURITY: Validate nonce in ID token for OpenID Connect (Google)
# This prevents token replay attacks (OpenID Connect Core 3.1.3.7)
# SECURITY: Validate ID token signature and nonce for OpenID Connect
# This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
if provider == "google" and state_record.nonce:
id_token = token.get("id_token")
if id_token:
import base64
import json
# Decode ID token payload (JWT format: header.payload.signature)
try:
payload_b64 = id_token.split(".")[1]
# Add padding if needed
padding = 4 - len(payload_b64) % 4
if padding != 4:
payload_b64 += "=" * padding
payload_json = base64.urlsafe_b64decode(payload_b64)
payload = json.loads(payload_json)
token_nonce = payload.get("nonce")
if token_nonce != state_record.nonce:
logger.warning(
f"OAuth nonce mismatch: expected {state_record.nonce}, "
f"got {token_nonce}"
)
raise AuthenticationError("Invalid ID token nonce")
except (IndexError, ValueError, json.JSONDecodeError) as e:
logger.warning(f"Failed to decode ID token for nonce validation: {e}")
# Continue without nonce validation if ID token is malformed
# The token will still be validated when getting user info
await OAuthService._verify_google_id_token(
id_token=str(id_token),
expected_nonce=str(state_record.nonce),
client_id=client_id,
)
except AuthenticationError:
raise
except Exception as e:
@@ -337,7 +318,9 @@ class OAuthService:
# Email can be None if user didn't grant email permission
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
email_raw = user_info.get("email")
provider_email: str | None = str(email_raw).lower().strip() if email_raw else None
provider_email: str | None = (
str(email_raw).lower().strip() if email_raw else None
)
if not provider_user_id:
raise AuthenticationError("Provider did not return user ID")
@@ -521,6 +504,106 @@ class OAuthService:
return user_info
# Google's OIDC configuration endpoints
GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
@staticmethod
async def _verify_google_id_token(
id_token: str,
expected_nonce: str,
client_id: str,
) -> dict[str, object]:
"""
Verify Google ID token signature and claims.
SECURITY: This properly verifies the ID token by:
1. Fetching Google's public keys (JWKS)
2. Verifying the JWT signature against the public key
3. Validating issuer, audience, expiry, and nonce claims
Args:
id_token: The ID token JWT string
expected_nonce: The nonce we sent in the authorization request
client_id: Our OAuth client ID (expected audience)
Returns:
Decoded ID token payload
Raises:
AuthenticationError: If verification fails
"""
import httpx
from jose import jwt as jose_jwt
from jose.exceptions import JWTError
try:
# Fetch Google's public keys (JWKS)
# In production, consider caching this with TTL matching Cache-Control header
async with httpx.AsyncClient() as client:
jwks_response = await client.get(
OAuthService.GOOGLE_JWKS_URL,
timeout=10.0,
)
jwks_response.raise_for_status()
jwks = jwks_response.json()
# Get the key ID from the token header
unverified_header = jose_jwt.get_unverified_header(id_token)
kid = unverified_header.get("kid")
if not kid:
raise AuthenticationError("ID token missing key ID (kid)")
# Find the matching public key
public_key = None
for key in jwks.get("keys", []):
if key.get("kid") == kid:
public_key = key
break
if not public_key:
raise AuthenticationError("ID token signed with unknown key")
# Verify the token signature and decode claims
# jose library will verify signature against the JWK
payload = jose_jwt.decode(
id_token,
public_key,
algorithms=["RS256"], # Google uses RS256
audience=client_id,
issuer=OAuthService.GOOGLE_ISSUERS,
options={
"verify_signature": True,
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"verify_iat": True,
},
)
# Verify nonce (OIDC replay attack protection)
token_nonce = payload.get("nonce")
if token_nonce != expected_nonce:
logger.warning(
f"OAuth ID token nonce mismatch: expected {expected_nonce}, "
f"got {token_nonce}"
)
raise AuthenticationError("Invalid ID token nonce")
logger.debug("Google ID token verified successfully")
return payload
except JWTError as e:
logger.warning(f"Google ID token verification failed: {e}")
raise AuthenticationError("Invalid ID token signature")
except httpx.HTTPError as e:
logger.error(f"Failed to fetch Google JWKS: {e}")
# If we can't verify the ID token, fail closed for security
raise AuthenticationError("Failed to verify ID token")
except Exception as e:
logger.error(f"Unexpected error verifying Google ID token: {e}")
raise AuthenticationError("ID token verification error")
@staticmethod
async def _create_oauth_user(
db: AsyncSession,