Enhance OAuth security, PKCE, and state validation
- Enforced stricter PKCE requirements by rejecting insecure 'plain' method for public clients. - Transitioned client secret hashing to bcrypt for improved security and migration compatibility. - Added constant-time comparison for state parameter validation to prevent timing attacks. - Improved error handling and logging for OAuth workflows, including malformed headers and invalid scopes. - Upgraded Google OIDC token validation to verify both signature and nonce. - Refactored OAuth service methods and schemas for better readability, consistency, and compliance with RFC specifications.
This commit is contained in:
@@ -196,28 +196,27 @@ async def authorize(
|
|||||||
valid_scopes = provider_service.validate_scopes(client, requested_scopes)
|
valid_scopes = provider_service.validate_scopes(client, requested_scopes)
|
||||||
except provider_service.InvalidScopeError as e:
|
except provider_service.InvalidScopeError as e:
|
||||||
# Redirect with error
|
# Redirect with error
|
||||||
error_params = {
|
scope_error_params: dict[str, str] = {"error": e.error}
|
||||||
"error": e.error,
|
if e.error_description:
|
||||||
"error_description": e.error_description,
|
scope_error_params["error_description"] = e.error_description
|
||||||
}
|
|
||||||
if state:
|
if state:
|
||||||
error_params["state"] = state
|
scope_error_params["state"] = state
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=f"{redirect_uri}?{urlencode(error_params)}",
|
url=f"{redirect_uri}?{urlencode(scope_error_params)}",
|
||||||
status_code=status.HTTP_302_FOUND,
|
status_code=status.HTTP_302_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Public clients MUST use PKCE
|
# Public clients MUST use PKCE
|
||||||
if client.client_type == "public":
|
if client.client_type == "public":
|
||||||
if not code_challenge or code_challenge_method != "S256":
|
if not code_challenge or code_challenge_method != "S256":
|
||||||
error_params = {
|
pkce_error_params: dict[str, str] = {
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
"error_description": "PKCE with S256 is required for public clients",
|
"error_description": "PKCE with S256 is required for public clients",
|
||||||
}
|
}
|
||||||
if state:
|
if state:
|
||||||
error_params["state"] = state
|
pkce_error_params["state"] = state
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=f"{redirect_uri}?{urlencode(error_params)}",
|
url=f"{redirect_uri}?{urlencode(pkce_error_params)}",
|
||||||
status_code=status.HTTP_302_FOUND,
|
status_code=status.HTTP_302_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -226,16 +225,18 @@ async def authorize(
|
|||||||
# Store authorization request in session and redirect to login
|
# Store authorization request in session and redirect to login
|
||||||
# The frontend will handle the return URL
|
# The frontend will handle the return URL
|
||||||
login_url = f"{settings.FRONTEND_URL}/login"
|
login_url = f"{settings.FRONTEND_URL}/login"
|
||||||
return_params = urlencode({
|
return_params = urlencode(
|
||||||
"oauth_authorize": "true",
|
{
|
||||||
"client_id": client_id,
|
"oauth_authorize": "true",
|
||||||
"redirect_uri": redirect_uri,
|
"client_id": client_id,
|
||||||
"scope": " ".join(valid_scopes),
|
"redirect_uri": redirect_uri,
|
||||||
"state": state,
|
"scope": " ".join(valid_scopes),
|
||||||
"code_challenge": code_challenge or "",
|
"state": state,
|
||||||
"code_challenge_method": code_challenge_method or "",
|
"code_challenge": code_challenge or "",
|
||||||
"nonce": nonce or "",
|
"code_challenge_method": code_challenge_method or "",
|
||||||
})
|
"nonce": nonce or "",
|
||||||
|
}
|
||||||
|
)
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=f"{login_url}?return_to=/auth/consent?{return_params}",
|
url=f"{login_url}?return_to=/auth/consent?{return_params}",
|
||||||
status_code=status.HTTP_302_FOUND,
|
status_code=status.HTTP_302_FOUND,
|
||||||
@@ -248,16 +249,18 @@ async def authorize(
|
|||||||
|
|
||||||
if not has_consent:
|
if not has_consent:
|
||||||
# Redirect to consent page
|
# Redirect to consent page
|
||||||
consent_params = urlencode({
|
consent_params = urlencode(
|
||||||
"client_id": client_id,
|
{
|
||||||
"client_name": client.client_name,
|
"client_id": client_id,
|
||||||
"redirect_uri": redirect_uri,
|
"client_name": client.client_name,
|
||||||
"scope": " ".join(valid_scopes),
|
"redirect_uri": redirect_uri,
|
||||||
"state": state,
|
"scope": " ".join(valid_scopes),
|
||||||
"code_challenge": code_challenge or "",
|
"state": state,
|
||||||
"code_challenge_method": code_challenge_method or "",
|
"code_challenge": code_challenge or "",
|
||||||
"nonce": nonce or "",
|
"code_challenge_method": code_challenge_method or "",
|
||||||
})
|
"nonce": nonce or "",
|
||||||
|
}
|
||||||
|
)
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}",
|
url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}",
|
||||||
status_code=status.HTTP_302_FOUND,
|
status_code=status.HTTP_302_FOUND,
|
||||||
@@ -277,10 +280,9 @@ async def authorize(
|
|||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
)
|
)
|
||||||
except provider_service.OAuthProviderError as e:
|
except provider_service.OAuthProviderError as e:
|
||||||
error_params = {
|
error_params: dict[str, str] = {"error": e.error}
|
||||||
"error": e.error,
|
if e.error_description:
|
||||||
"error_description": e.error_description,
|
error_params["error_description"] = e.error_description
|
||||||
}
|
|
||||||
if state:
|
if state:
|
||||||
error_params["state"] = state
|
error_params["state"] = state
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
@@ -340,14 +342,14 @@ async def submit_consent(
|
|||||||
|
|
||||||
# If user denied, redirect with error
|
# If user denied, redirect with error
|
||||||
if not approved:
|
if not approved:
|
||||||
error_params = {
|
denied_params: dict[str, str] = {
|
||||||
"error": "access_denied",
|
"error": "access_denied",
|
||||||
"error_description": "User denied authorization",
|
"error_description": "User denied authorization",
|
||||||
}
|
}
|
||||||
if state:
|
if state:
|
||||||
error_params["state"] = state
|
denied_params["state"] = state
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=f"{redirect_uri}?{urlencode(error_params)}",
|
url=f"{redirect_uri}?{urlencode(denied_params)}",
|
||||||
status_code=status.HTTP_302_FOUND,
|
status_code=status.HTTP_302_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -356,9 +358,7 @@ async def submit_consent(
|
|||||||
valid_scopes = provider_service.validate_scopes(client, granted_scopes)
|
valid_scopes = provider_service.validate_scopes(client, granted_scopes)
|
||||||
|
|
||||||
# Record consent
|
# Record consent
|
||||||
await provider_service.grant_consent(
|
await provider_service.grant_consent(db, current_user.id, client_id, valid_scopes)
|
||||||
db, current_user.id, client_id, valid_scopes
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate authorization code
|
# Generate authorization code
|
||||||
try:
|
try:
|
||||||
@@ -374,10 +374,9 @@ async def submit_consent(
|
|||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
)
|
)
|
||||||
except provider_service.OAuthProviderError as e:
|
except provider_service.OAuthProviderError as e:
|
||||||
error_params = {
|
error_params: dict[str, str] = {"error": e.error}
|
||||||
"error": e.error,
|
if e.error_description:
|
||||||
"error_description": e.error_description,
|
error_params["error_description"] = e.error_description
|
||||||
}
|
|
||||||
if state:
|
if state:
|
||||||
error_params["state"] = state
|
error_params["state"] = state
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
@@ -439,6 +438,7 @@ async def token(
|
|||||||
auth_header = request.headers.get("Authorization", "")
|
auth_header = request.headers.get("Authorization", "")
|
||||||
if auth_header.startswith("Basic "):
|
if auth_header.startswith("Basic "):
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||||
client_id, client_secret = decoded.split(":", 1)
|
client_id, client_secret = decoded.split(":", 1)
|
||||||
@@ -549,6 +549,7 @@ async def revoke(
|
|||||||
auth_header = request.headers.get("Authorization", "")
|
auth_header = request.headers.get("Authorization", "")
|
||||||
if auth_header.startswith("Basic "):
|
if auth_header.startswith("Basic "):
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||||
client_id, client_secret = decoded.split(":", 1)
|
client_id, client_secret = decoded.split(":", 1)
|
||||||
@@ -619,6 +620,7 @@ async def introspect(
|
|||||||
auth_header = request.headers.get("Authorization", "")
|
auth_header = request.headers.get("Authorization", "")
|
||||||
if auth_header.startswith("Basic "):
|
if auth_header.startswith("Basic "):
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||||
client_id, client_secret = decoded.split(":", 1)
|
client_id, client_secret = decoded.split(":", 1)
|
||||||
@@ -804,7 +806,9 @@ async def list_my_consents(
|
|||||||
"client_id": consent.client_id,
|
"client_id": consent.client_id,
|
||||||
"client_name": client.client_name,
|
"client_name": client.client_name,
|
||||||
"client_description": client.client_description,
|
"client_description": client.client_description,
|
||||||
"granted_scopes": consent.granted_scopes.split() if consent.granted_scopes else [],
|
"granted_scopes": consent.granted_scopes.split()
|
||||||
|
if consent.granted_scopes
|
||||||
|
else [],
|
||||||
"granted_at": consent.created_at.isoformat(),
|
"granted_at": consent.created_at.isoformat(),
|
||||||
}
|
}
|
||||||
for consent, client in rows
|
for consent, client in rows
|
||||||
|
|||||||
@@ -515,11 +515,11 @@ class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
|
|||||||
client_secret_hash = None
|
client_secret_hash = None
|
||||||
if obj_in.client_type == "confidential":
|
if obj_in.client_type == "confidential":
|
||||||
client_secret = secrets.token_urlsafe(48)
|
client_secret = secrets.token_urlsafe(48)
|
||||||
# In production, use proper password hashing (bcrypt)
|
# SECURITY: Use bcrypt for secret storage (not SHA-256)
|
||||||
# For now, we store a hash placeholder
|
# bcrypt is computationally expensive, making brute-force attacks infeasible
|
||||||
import hashlib
|
from app.core.auth import get_password_hash
|
||||||
|
|
||||||
client_secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
client_secret_hash = get_password_hash(client_secret)
|
||||||
|
|
||||||
db_obj = OAuthClient(
|
db_obj = OAuthClient(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
@@ -632,13 +632,22 @@ class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
|
|||||||
if client is None or client.client_secret_hash is None:
|
if client is None or client.client_secret_hash is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Verify secret
|
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
||||||
import hashlib
|
# This supports both old SHA-256 hashes (for migration) and new bcrypt hashes
|
||||||
|
from app.core.auth import verify_password
|
||||||
|
|
||||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
|
||||||
# Cast to str for type safety with compare_digest
|
|
||||||
stored_hash: str = str(client.client_secret_hash)
|
stored_hash: str = str(client.client_secret_hash)
|
||||||
return secrets.compare_digest(stored_hash, secret_hash)
|
|
||||||
|
# Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256
|
||||||
|
if stored_hash.startswith("$2"):
|
||||||
|
# New bcrypt format
|
||||||
|
return verify_password(client_secret, stored_hash)
|
||||||
|
else:
|
||||||
|
# Legacy SHA-256 format - still support for migration
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||||
|
return secrets.compare_digest(stored_hash, secret_hash)
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e: # pragma: no cover
|
||||||
logger.error(f"Error verifying client secret: {e!s}")
|
logger.error(f"Error verifying client secret: {e!s}")
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -339,9 +339,7 @@ class OAuthTokenResponse(BaseModel):
|
|||||||
token_type: str = Field(
|
token_type: str = Field(
|
||||||
default="Bearer", description="The type of token (typically 'Bearer')"
|
default="Bearer", description="The type of token (typically 'Bearer')"
|
||||||
)
|
)
|
||||||
expires_in: int | None = Field(
|
expires_in: int | None = Field(None, description="Token lifetime in seconds")
|
||||||
None, description="Token lifetime in seconds"
|
|
||||||
)
|
|
||||||
refresh_token: str | None = Field(
|
refresh_token: str | None = Field(
|
||||||
None, description="Refresh token for obtaining new access tokens"
|
None, description="Refresh token for obtaining new access tokens"
|
||||||
)
|
)
|
||||||
@@ -365,39 +363,21 @@ class OAuthTokenResponse(BaseModel):
|
|||||||
class OAuthTokenIntrospectionResponse(BaseModel):
|
class OAuthTokenIntrospectionResponse(BaseModel):
|
||||||
"""OAuth 2.0 Token Introspection Response (RFC 7662)."""
|
"""OAuth 2.0 Token Introspection Response (RFC 7662)."""
|
||||||
|
|
||||||
active: bool = Field(
|
active: bool = Field(..., description="Whether the token is currently active")
|
||||||
..., description="Whether the token is currently active"
|
scope: str | None = Field(None, description="Space-separated list of scopes")
|
||||||
)
|
client_id: str | None = Field(None, description="Client identifier for the token")
|
||||||
scope: str | None = Field(
|
|
||||||
None, description="Space-separated list of scopes"
|
|
||||||
)
|
|
||||||
client_id: str | None = Field(
|
|
||||||
None, description="Client identifier for the token"
|
|
||||||
)
|
|
||||||
username: str | None = Field(
|
username: str | None = Field(
|
||||||
None, description="Human-readable identifier for the resource owner"
|
None, description="Human-readable identifier for the resource owner"
|
||||||
)
|
)
|
||||||
token_type: str | None = Field(
|
token_type: str | None = Field(
|
||||||
None, description="Type of the token (e.g., 'Bearer')"
|
None, description="Type of the token (e.g., 'Bearer')"
|
||||||
)
|
)
|
||||||
exp: int | None = Field(
|
exp: int | None = Field(None, description="Token expiration time (Unix timestamp)")
|
||||||
None, description="Token expiration time (Unix timestamp)"
|
iat: int | None = Field(None, description="Token issue time (Unix timestamp)")
|
||||||
)
|
nbf: int | None = Field(None, description="Token not-before time (Unix timestamp)")
|
||||||
iat: int | None = Field(
|
sub: str | None = Field(None, description="Subject of the token (user ID)")
|
||||||
None, description="Token issue time (Unix timestamp)"
|
aud: str | None = Field(None, description="Intended audience of the token")
|
||||||
)
|
iss: str | None = Field(None, description="Issuer of the token")
|
||||||
nbf: int | None = Field(
|
|
||||||
None, description="Token not-before time (Unix timestamp)"
|
|
||||||
)
|
|
||||||
sub: str | None = Field(
|
|
||||||
None, description="Subject of the token (user ID)"
|
|
||||||
)
|
|
||||||
aud: str | None = Field(
|
|
||||||
None, description="Intended audience of the token"
|
|
||||||
)
|
|
||||||
iss: str | None = Field(
|
|
||||||
None, description="Issuer of the token"
|
|
||||||
)
|
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -126,16 +126,22 @@ def hash_token(token: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
|
def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
|
||||||
"""Verify PKCE code_verifier against stored code_challenge."""
|
"""
|
||||||
if method == "S256":
|
Verify PKCE code_verifier against stored code_challenge.
|
||||||
# SHA-256 hash, then base64url encode
|
|
||||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
SECURITY: Only S256 method is supported. The 'plain' method provides
|
||||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
no security benefit and is explicitly rejected per RFC 7636 Section 4.3.
|
||||||
return secrets.compare_digest(computed, code_challenge)
|
"""
|
||||||
elif method == "plain":
|
if method != "S256":
|
||||||
# Direct comparison (not recommended, but supported)
|
# SECURITY: Reject any method other than S256
|
||||||
return secrets.compare_digest(code_verifier, code_challenge)
|
# 'plain' method provides no security against code interception attacks
|
||||||
return False
|
logger.warning(f"PKCE verification rejected for unsupported method: {method}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
|
||||||
|
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||||
|
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||||
|
return secrets.compare_digest(computed, code_challenge)
|
||||||
|
|
||||||
|
|
||||||
def parse_scope(scope: str) -> list[str]:
|
def parse_scope(scope: str) -> list[str]:
|
||||||
@@ -198,10 +204,21 @@ async def validate_client(
|
|||||||
if not client.client_secret_hash:
|
if not client.client_secret_hash:
|
||||||
raise InvalidClientError("Client not configured with secret")
|
raise InvalidClientError("Client not configured with secret")
|
||||||
|
|
||||||
# Verify secret using SHA256 hash (consistent with CRUD)
|
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
||||||
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
# Supports both bcrypt and legacy SHA-256 hashes for migration
|
||||||
if not secrets.compare_digest(computed_hash, client.client_secret_hash):
|
from app.core.auth import verify_password
|
||||||
raise InvalidClientError("Invalid client secret")
|
|
||||||
|
stored_hash = str(client.client_secret_hash)
|
||||||
|
|
||||||
|
if stored_hash.startswith("$2"):
|
||||||
|
# New bcrypt format
|
||||||
|
if not verify_password(client_secret, stored_hash):
|
||||||
|
raise InvalidClientError("Invalid client secret")
|
||||||
|
else:
|
||||||
|
# Legacy SHA-256 format
|
||||||
|
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||||
|
if not secrets.compare_digest(computed_hash, stored_hash):
|
||||||
|
raise InvalidClientError("Invalid client secret")
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@@ -246,9 +263,7 @@ def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[st
|
|||||||
# Warn if some scopes were filtered out
|
# Warn if some scopes were filtered out
|
||||||
invalid = requested - allowed
|
invalid = requested - allowed
|
||||||
if invalid:
|
if invalid:
|
||||||
logger.warning(
|
logger.warning(f"Client {client.client_id} requested invalid scopes: {invalid}")
|
||||||
f"Client {client.client_id} requested invalid scopes: {invalid}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return list(valid)
|
return list(valid)
|
||||||
|
|
||||||
@@ -382,17 +397,17 @@ async def exchange_authorization_code(
|
|||||||
f"Authorization code reuse detected for client {existing_code.client_id}"
|
f"Authorization code reuse detected for client {existing_code.client_id}"
|
||||||
)
|
)
|
||||||
await revoke_tokens_for_user_client(
|
await revoke_tokens_for_user_client(
|
||||||
db, existing_code.user_id, existing_code.client_id
|
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
|
||||||
)
|
)
|
||||||
raise InvalidGrantError("Authorization code has already been used")
|
raise InvalidGrantError("Authorization code has already been used")
|
||||||
else:
|
else:
|
||||||
raise InvalidGrantError("Invalid authorization code")
|
raise InvalidGrantError("Invalid authorization code")
|
||||||
|
|
||||||
# Now fetch the full auth code record
|
# Now fetch the full auth code record
|
||||||
result = await db.execute(
|
auth_code_result = await db.execute(
|
||||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
|
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
|
||||||
)
|
)
|
||||||
auth_code = result.scalar_one()
|
auth_code = auth_code_result.scalar_one()
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
if auth_code.is_expired:
|
if auth_code.is_expired:
|
||||||
@@ -413,10 +428,14 @@ async def exchange_authorization_code(
|
|||||||
if client.client_type == "confidential":
|
if client.client_type == "confidential":
|
||||||
if not client_secret:
|
if not client_secret:
|
||||||
raise InvalidClientError("Client secret required for confidential clients")
|
raise InvalidClientError("Client secret required for confidential clients")
|
||||||
client = await validate_client(db, client_id, client_secret, require_secret=True)
|
client = await validate_client(
|
||||||
|
db, client_id, client_secret, require_secret=True
|
||||||
|
)
|
||||||
elif client_secret:
|
elif client_secret:
|
||||||
# Public client provided secret - validate it if given
|
# Public client provided secret - validate it if given
|
||||||
client = await validate_client(db, client_id, client_secret, require_secret=True)
|
client = await validate_client(
|
||||||
|
db, client_id, client_secret, require_secret=True
|
||||||
|
)
|
||||||
|
|
||||||
# Verify PKCE
|
# Verify PKCE
|
||||||
if auth_code.code_challenge:
|
if auth_code.code_challenge:
|
||||||
@@ -424,8 +443,8 @@ async def exchange_authorization_code(
|
|||||||
raise InvalidGrantError("code_verifier required")
|
raise InvalidGrantError("code_verifier required")
|
||||||
if not verify_pkce(
|
if not verify_pkce(
|
||||||
code_verifier,
|
code_verifier,
|
||||||
auth_code.code_challenge,
|
str(auth_code.code_challenge),
|
||||||
auth_code.code_challenge_method or "S256",
|
str(auth_code.code_challenge_method or "S256"),
|
||||||
):
|
):
|
||||||
raise InvalidGrantError("Invalid code_verifier")
|
raise InvalidGrantError("Invalid code_verifier")
|
||||||
elif client.client_type == "public":
|
elif client.client_type == "public":
|
||||||
@@ -443,8 +462,8 @@ async def exchange_authorization_code(
|
|||||||
db=db,
|
db=db,
|
||||||
client=client,
|
client=client,
|
||||||
user=user,
|
user=user,
|
||||||
scope=auth_code.scope,
|
scope=str(auth_code.scope),
|
||||||
nonce=auth_code.nonce,
|
nonce=str(auth_code.nonce) if auth_code.nonce else None,
|
||||||
device_info=device_info,
|
device_info=device_info,
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
@@ -487,16 +506,20 @@ async def create_tokens(
|
|||||||
access_expires = now + timedelta(seconds=access_token_lifetime)
|
access_expires = now + timedelta(seconds=access_token_lifetime)
|
||||||
|
|
||||||
# Refresh token expiry
|
# Refresh token expiry
|
||||||
refresh_token_lifetime = int(client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400))
|
refresh_token_lifetime = int(
|
||||||
|
client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400)
|
||||||
|
)
|
||||||
refresh_expires = now + timedelta(seconds=refresh_token_lifetime)
|
refresh_expires = now + timedelta(seconds=refresh_token_lifetime)
|
||||||
|
|
||||||
# Create JWT access token
|
# Create JWT access token
|
||||||
|
# SECURITY: Include all standard JWT claims per RFC 7519
|
||||||
access_token_payload = {
|
access_token_payload = {
|
||||||
"iss": settings.OAUTH_ISSUER,
|
"iss": settings.OAUTH_ISSUER,
|
||||||
"sub": str(user.id),
|
"sub": str(user.id),
|
||||||
"aud": client.client_id,
|
"aud": client.client_id,
|
||||||
"exp": int(access_expires.timestamp()),
|
"exp": int(access_expires.timestamp()),
|
||||||
"iat": int(now.timestamp()),
|
"iat": int(now.timestamp()),
|
||||||
|
"nbf": int(now.timestamp()), # Not Before - token is valid immediately
|
||||||
"jti": jti,
|
"jti": jti,
|
||||||
"scope": scope,
|
"scope": scope,
|
||||||
"client_id": client.client_id,
|
"client_id": client.client_id,
|
||||||
@@ -581,7 +604,7 @@ async def refresh_tokens(
|
|||||||
OAuthProviderRefreshToken.token_hash == token_hash
|
OAuthProviderRefreshToken.token_hash == token_hash
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
token_record = result.scalar_one_or_none()
|
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not token_record:
|
if not token_record:
|
||||||
raise InvalidGrantError("Invalid refresh token")
|
raise InvalidGrantError("Invalid refresh token")
|
||||||
@@ -608,36 +631,37 @@ async def refresh_tokens(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
|
||||||
select(User).where(User.id == token_record.user_id)
|
|
||||||
)
|
|
||||||
user = user_result.scalar_one_or_none()
|
user = user_result.scalar_one_or_none()
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
raise InvalidGrantError("User not found or inactive")
|
raise InvalidGrantError("User not found or inactive")
|
||||||
|
|
||||||
# Validate scope (can only reduce, not expand)
|
# Validate scope (can only reduce, not expand)
|
||||||
original_scopes = set(parse_scope(token_record.scope))
|
token_scope = str(token_record.scope) if token_record.scope else ""
|
||||||
|
original_scopes = set(parse_scope(token_scope))
|
||||||
if scope:
|
if scope:
|
||||||
requested_scopes = set(parse_scope(scope))
|
requested_scopes = set(parse_scope(scope))
|
||||||
if not requested_scopes.issubset(original_scopes):
|
if not requested_scopes.issubset(original_scopes):
|
||||||
raise InvalidScopeError("Cannot expand scope on refresh")
|
raise InvalidScopeError("Cannot expand scope on refresh")
|
||||||
final_scope = join_scope(list(requested_scopes))
|
final_scope = join_scope(list(requested_scopes))
|
||||||
else:
|
else:
|
||||||
final_scope = token_record.scope
|
final_scope = token_scope
|
||||||
|
|
||||||
# Revoke old refresh token (token rotation)
|
# Revoke old refresh token (token rotation)
|
||||||
token_record.revoked = True
|
token_record.revoked = True # type: ignore[assignment]
|
||||||
token_record.last_used_at = datetime.now(UTC)
|
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# Issue new tokens
|
# Issue new tokens
|
||||||
|
device = str(token_record.device_info) if token_record.device_info else None
|
||||||
|
ip_addr = str(token_record.ip_address) if token_record.ip_address else None
|
||||||
return await create_tokens(
|
return await create_tokens(
|
||||||
db=db,
|
db=db,
|
||||||
client=client,
|
client=client,
|
||||||
user=user,
|
user=user,
|
||||||
scope=final_scope,
|
scope=final_scope,
|
||||||
device_info=device_info or token_record.device_info,
|
device_info=device_info or device,
|
||||||
ip_address=ip_address or token_record.ip_address,
|
ip_address=ip_address or ip_addr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -685,7 +709,7 @@ async def revoke_token(
|
|||||||
if client_id and refresh_record.client_id != client_id:
|
if client_id and refresh_record.client_id != client_id:
|
||||||
raise InvalidClientError("Token was not issued to this client")
|
raise InvalidClientError("Token was not issued to this client")
|
||||||
|
|
||||||
refresh_record.revoked = True
|
refresh_record.revoked = True # type: ignore[assignment]
|
||||||
await db.commit()
|
await db.commit()
|
||||||
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
|
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
|
||||||
return True
|
return True
|
||||||
@@ -699,7 +723,10 @@ async def revoke_token(
|
|||||||
token,
|
token,
|
||||||
settings.SECRET_KEY,
|
settings.SECRET_KEY,
|
||||||
algorithms=[settings.ALGORITHM],
|
algorithms=[settings.ALGORITHM],
|
||||||
options={"verify_exp": False, "verify_aud": False}, # Allow expired tokens
|
options={
|
||||||
|
"verify_exp": False,
|
||||||
|
"verify_aud": False,
|
||||||
|
}, # Allow expired tokens
|
||||||
)
|
)
|
||||||
jti = payload.get("jti")
|
jti = payload.get("jti")
|
||||||
if jti:
|
if jti:
|
||||||
@@ -713,7 +740,7 @@ async def revoke_token(
|
|||||||
if refresh_record:
|
if refresh_record:
|
||||||
if client_id and refresh_record.client_id != client_id:
|
if client_id and refresh_record.client_id != client_id:
|
||||||
raise InvalidClientError("Token was not issued to this client")
|
raise InvalidClientError("Token was not issued to this client")
|
||||||
refresh_record.revoked = True
|
refresh_record.revoked = True # type: ignore[assignment]
|
||||||
await db.commit()
|
await db.commit()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Revoked refresh token via access token JTI {jti[:8]}..."
|
f"Revoked refresh token via access token JTI {jti[:8]}..."
|
||||||
@@ -756,7 +783,7 @@ async def revoke_tokens_for_user_client(
|
|||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
token.revoked = True
|
token.revoked = True # type: ignore[assignment]
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
@@ -793,7 +820,7 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
|
|||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
token.revoked = True
|
token.revoked = True # type: ignore[assignment]
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
@@ -843,7 +870,9 @@ async def introspect_token(
|
|||||||
token,
|
token,
|
||||||
settings.SECRET_KEY,
|
settings.SECRET_KEY,
|
||||||
algorithms=[settings.ALGORITHM],
|
algorithms=[settings.ALGORITHM],
|
||||||
options={"verify_aud": False}, # Don't require audience match for introspection
|
options={
|
||||||
|
"verify_aud": False
|
||||||
|
}, # Don't require audience match for introspection
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if associated refresh token is revoked
|
# Check if associated refresh token is revoked
|
||||||
@@ -953,9 +982,10 @@ async def grant_consent(
|
|||||||
|
|
||||||
if consent:
|
if consent:
|
||||||
# Merge scopes
|
# Merge scopes
|
||||||
existing = set(parse_scope(consent.granted_scopes))
|
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
|
||||||
|
existing = set(parse_scope(granted))
|
||||||
new_scopes = existing | set(scopes)
|
new_scopes = existing | set(scopes)
|
||||||
consent.granted_scopes = join_scope(list(new_scopes))
|
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
|
||||||
else:
|
else:
|
||||||
consent = OAuthConsent(
|
consent = OAuthConsent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -993,7 +1023,7 @@ async def revoke_consent(
|
|||||||
await revoke_tokens_for_user_client(db, user_id, client_id)
|
await revoke_tokens_for_user_client(db, user_id, client_id)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return result.rowcount > 0
|
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -1016,7 +1046,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return result.rowcount
|
return result.rowcount # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||||
@@ -1036,4 +1066,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return result.rowcount
|
return result.rowcount # type: ignore[attr-defined]
|
||||||
|
|||||||
@@ -282,35 +282,16 @@ class OAuthService:
|
|||||||
**token_params,
|
**token_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
# SECURITY: Validate nonce in ID token for OpenID Connect (Google)
|
# SECURITY: Validate ID token signature and nonce for OpenID Connect
|
||||||
# This prevents token replay attacks (OpenID Connect Core 3.1.3.7)
|
# This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
|
||||||
if provider == "google" and state_record.nonce:
|
if provider == "google" and state_record.nonce:
|
||||||
id_token = token.get("id_token")
|
id_token = token.get("id_token")
|
||||||
if id_token:
|
if id_token:
|
||||||
import base64
|
await OAuthService._verify_google_id_token(
|
||||||
import json
|
id_token=str(id_token),
|
||||||
|
expected_nonce=str(state_record.nonce),
|
||||||
# Decode ID token payload (JWT format: header.payload.signature)
|
client_id=client_id,
|
||||||
try:
|
)
|
||||||
payload_b64 = id_token.split(".")[1]
|
|
||||||
# Add padding if needed
|
|
||||||
padding = 4 - len(payload_b64) % 4
|
|
||||||
if padding != 4:
|
|
||||||
payload_b64 += "=" * padding
|
|
||||||
payload_json = base64.urlsafe_b64decode(payload_b64)
|
|
||||||
payload = json.loads(payload_json)
|
|
||||||
|
|
||||||
token_nonce = payload.get("nonce")
|
|
||||||
if token_nonce != state_record.nonce:
|
|
||||||
logger.warning(
|
|
||||||
f"OAuth nonce mismatch: expected {state_record.nonce}, "
|
|
||||||
f"got {token_nonce}"
|
|
||||||
)
|
|
||||||
raise AuthenticationError("Invalid ID token nonce")
|
|
||||||
except (IndexError, ValueError, json.JSONDecodeError) as e:
|
|
||||||
logger.warning(f"Failed to decode ID token for nonce validation: {e}")
|
|
||||||
# Continue without nonce validation if ID token is malformed
|
|
||||||
# The token will still be validated when getting user info
|
|
||||||
except AuthenticationError:
|
except AuthenticationError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -337,7 +318,9 @@ class OAuthService:
|
|||||||
# Email can be None if user didn't grant email permission
|
# Email can be None if user didn't grant email permission
|
||||||
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
|
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
|
||||||
email_raw = user_info.get("email")
|
email_raw = user_info.get("email")
|
||||||
provider_email: str | None = str(email_raw).lower().strip() if email_raw else None
|
provider_email: str | None = (
|
||||||
|
str(email_raw).lower().strip() if email_raw else None
|
||||||
|
)
|
||||||
|
|
||||||
if not provider_user_id:
|
if not provider_user_id:
|
||||||
raise AuthenticationError("Provider did not return user ID")
|
raise AuthenticationError("Provider did not return user ID")
|
||||||
@@ -521,6 +504,106 @@ class OAuthService:
|
|||||||
|
|
||||||
return user_info
|
return user_info
|
||||||
|
|
||||||
|
# Google's OIDC configuration endpoints
|
||||||
|
GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||||
|
GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _verify_google_id_token(
|
||||||
|
id_token: str,
|
||||||
|
expected_nonce: str,
|
||||||
|
client_id: str,
|
||||||
|
) -> dict[str, object]:
|
||||||
|
"""
|
||||||
|
Verify Google ID token signature and claims.
|
||||||
|
|
||||||
|
SECURITY: This properly verifies the ID token by:
|
||||||
|
1. Fetching Google's public keys (JWKS)
|
||||||
|
2. Verifying the JWT signature against the public key
|
||||||
|
3. Validating issuer, audience, expiry, and nonce claims
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id_token: The ID token JWT string
|
||||||
|
expected_nonce: The nonce we sent in the authorization request
|
||||||
|
client_id: Our OAuth client ID (expected audience)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decoded ID token payload
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If verification fails
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
from jose import jwt as jose_jwt
|
||||||
|
from jose.exceptions import JWTError
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Fetch Google's public keys (JWKS)
|
||||||
|
# In production, consider caching this with TTL matching Cache-Control header
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
jwks_response = await client.get(
|
||||||
|
OAuthService.GOOGLE_JWKS_URL,
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
jwks_response.raise_for_status()
|
||||||
|
jwks = jwks_response.json()
|
||||||
|
|
||||||
|
# Get the key ID from the token header
|
||||||
|
unverified_header = jose_jwt.get_unverified_header(id_token)
|
||||||
|
kid = unverified_header.get("kid")
|
||||||
|
if not kid:
|
||||||
|
raise AuthenticationError("ID token missing key ID (kid)")
|
||||||
|
|
||||||
|
# Find the matching public key
|
||||||
|
public_key = None
|
||||||
|
for key in jwks.get("keys", []):
|
||||||
|
if key.get("kid") == kid:
|
||||||
|
public_key = key
|
||||||
|
break
|
||||||
|
|
||||||
|
if not public_key:
|
||||||
|
raise AuthenticationError("ID token signed with unknown key")
|
||||||
|
|
||||||
|
# Verify the token signature and decode claims
|
||||||
|
# jose library will verify signature against the JWK
|
||||||
|
payload = jose_jwt.decode(
|
||||||
|
id_token,
|
||||||
|
public_key,
|
||||||
|
algorithms=["RS256"], # Google uses RS256
|
||||||
|
audience=client_id,
|
||||||
|
issuer=OAuthService.GOOGLE_ISSUERS,
|
||||||
|
options={
|
||||||
|
"verify_signature": True,
|
||||||
|
"verify_aud": True,
|
||||||
|
"verify_iss": True,
|
||||||
|
"verify_exp": True,
|
||||||
|
"verify_iat": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify nonce (OIDC replay attack protection)
|
||||||
|
token_nonce = payload.get("nonce")
|
||||||
|
if token_nonce != expected_nonce:
|
||||||
|
logger.warning(
|
||||||
|
f"OAuth ID token nonce mismatch: expected {expected_nonce}, "
|
||||||
|
f"got {token_nonce}"
|
||||||
|
)
|
||||||
|
raise AuthenticationError("Invalid ID token nonce")
|
||||||
|
|
||||||
|
logger.debug("Google ID token verified successfully")
|
||||||
|
return payload
|
||||||
|
|
||||||
|
except JWTError as e:
|
||||||
|
logger.warning(f"Google ID token verification failed: {e}")
|
||||||
|
raise AuthenticationError("Invalid ID token signature")
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.error(f"Failed to fetch Google JWKS: {e}")
|
||||||
|
# If we can't verify the ID token, fail closed for security
|
||||||
|
raise AuthenticationError("Failed to verify ID token")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error verifying Google ID token: {e}")
|
||||||
|
raise AuthenticationError("ID token verification error")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _create_oauth_user(
|
async def _create_oauth_user(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
|
|||||||
@@ -21,6 +21,24 @@ import { Loader2 } from 'lucide-react';
|
|||||||
import { useOAuthCallback } from '@/lib/api/hooks/useOAuth';
|
import { useOAuthCallback } from '@/lib/api/hooks/useOAuth';
|
||||||
import config from '@/config/app.config';
|
import config from '@/config/app.config';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* SECURITY: Constant-time string comparison to prevent timing attacks.
|
||||||
|
* JavaScript's === operator may short-circuit, potentially leaking information.
|
||||||
|
* While timing attacks on frontend state are less critical (state is in URL),
|
||||||
|
* this provides defense-in-depth.
|
||||||
|
*/
|
||||||
|
function constantTimeCompare(a: string, b: string): boolean {
|
||||||
|
if (a.length !== b.length) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = 0;
|
||||||
|
for (let i = 0; i < a.length; i++) {
|
||||||
|
result |= a.charCodeAt(i) ^ b.charCodeAt(i);
|
||||||
|
}
|
||||||
|
return result === 0;
|
||||||
|
}
|
||||||
|
|
||||||
export default function OAuthCallbackPage() {
|
export default function OAuthCallbackPage() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
const searchParams = useSearchParams();
|
const searchParams = useSearchParams();
|
||||||
@@ -55,8 +73,9 @@ export default function OAuthCallbackPage() {
|
|||||||
|
|
||||||
// SECURITY: Validate state parameter against stored value (CSRF protection)
|
// SECURITY: Validate state parameter against stored value (CSRF protection)
|
||||||
// This prevents cross-site request forgery attacks
|
// This prevents cross-site request forgery attacks
|
||||||
|
// Use constant-time comparison for defense-in-depth
|
||||||
const storedState = sessionStorage.getItem('oauth_state');
|
const storedState = sessionStorage.getItem('oauth_state');
|
||||||
if (!storedState || storedState !== state) {
|
if (!storedState || !constantTimeCompare(storedState, state)) {
|
||||||
// Clean up stored state on mismatch
|
// Clean up stored state on mismatch
|
||||||
sessionStorage.removeItem('oauth_state');
|
sessionStorage.removeItem('oauth_state');
|
||||||
sessionStorage.removeItem('oauth_mode');
|
sessionStorage.removeItem('oauth_mode');
|
||||||
|
|||||||
Reference in New Issue
Block a user