From 400d6f6f7571517e9e02f3b9a9571de8dc988a94 Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Tue, 25 Nov 2025 23:50:43 +0100 Subject: [PATCH] Enhance OAuth security and state validation - Implemented stricter OAuth security measures, including CSRF protection via state parameter validation and redirect_uri checks. - Updated OAuth models to support timezone-aware datetime comparisons, replacing deprecated `utcnow`. - Enhanced logging for malformed Basic auth headers during token, introspect, and revoke requests. - Added allowlist validation for OAuth provider domains to prevent open redirect attacks. - Improved nonce validation for OpenID Connect tokens, ensuring token integrity during Google provider flows. - Updated E2E and unit tests to cover new security features and expanded OAuth state handling scenarios. --- AGENTS.md | 6 +- backend/app/api/routes/oauth_provider.py | 31 +++++-- .../app/models/oauth_authorization_code.py | 10 ++- backend/app/models/oauth_provider_token.py | 10 ++- .../app/services/oauth_provider_service.py | 81 +++++++++++------ backend/app/services/oauth_service.py | 44 +++++++++- backend/tests/e2e/conftest.py | 3 - backend/tests/e2e/test_admin_workflows.py | 4 +- .../tests/e2e/test_organization_workflows.py | 2 +- backend/tests/services/test_oauth_service.py | 12 +++ frontend/messages/en.json | 1 + frontend/messages/it.json | 1 + .../(auth)/auth/callback/[provider]/page.tsx | 12 +++ frontend/src/lib/api/hooks/useOAuth.ts | 86 +++++++++++++++++-- 14 files changed, 246 insertions(+), 57 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 8c8a100..923b667 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -42,7 +42,7 @@ Default superuser (change in production): │ │ ├── schemas/ # Pydantic request/response schemas │ │ ├── services/ # Business logic layer │ │ └── utils/ # Utilities (security, device detection) -│ ├── tests/ # 97% coverage, 743 tests +│ ├── tests/ # 96% coverage, 987 tests │ └── alembic/ # Database migrations │ └── frontend/ # Next.js 15 frontend @@ -128,7 +128,7 @@ Permission dependencies in `api/dependencies/permissions.py`: ### Testing Infrastructure **Backend Unit/Integration (pytest + SQLite):** -- 97% coverage, 743+ tests +- 96% coverage, 987 tests - Security-focused: JWT attacks, session hijacking, privilege escalation - Async fixtures in `tests/conftest.py` - Run: `IS_TEST=True uv run pytest` or `make test` @@ -265,7 +265,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a - Organization system (multi-tenant with RBAC) - Admin panel (user/org management, bulk operations) - **Internationalization (i18n)** with English and Italian -- Comprehensive test coverage (97% backend, 97% frontend unit, 56 E2E tests) +- Comprehensive test coverage (96% backend, 97% frontend unit, 56 E2E tests) - Design system documentation - **Marketing landing page** with animations - **`/dev` documentation portal** with live examples diff --git a/backend/app/api/routes/oauth_provider.py b/backend/app/api/routes/oauth_provider.py index b8d2b26..4bc8ef4 100644 --- a/backend/app/api/routes/oauth_provider.py +++ b/backend/app/api/routes/oauth_provider.py @@ -169,11 +169,12 @@ async def authorize( detail="invalid_request: response_type must be 'code'", ) - # Validate PKCE method if provided - if code_challenge_method and code_challenge_method not in ["S256", "plain"]: + # Validate PKCE method if provided - ONLY S256 is allowed (RFC 7636 Section 4.3) + # "plain" method provides no security benefit and MUST NOT be used + if code_challenge_method and code_challenge_method != "S256": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="invalid_request: code_challenge_method must be 'S256'", + detail="invalid_request: code_challenge_method must be 'S256' (plain is not supported)", ) # Validate client @@ -441,8 +442,12 @@ async def token( try: decoded = base64.b64decode(auth_header[6:]).decode() client_id, client_secret = decoded.split(":", 1) - except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body - pass + except Exception as e: + # Log malformed Basic auth for security monitoring + logger.warning( + f"Malformed Basic auth header in token request: {type(e).__name__}" + ) + # Fall back to form body if not client_id: raise HTTPException( @@ -547,8 +552,12 @@ async def revoke( try: decoded = base64.b64decode(auth_header[6:]).decode() client_id, client_secret = decoded.split(":", 1) - except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body - pass + except Exception as e: + # Log malformed Basic auth for security monitoring + logger.warning( + f"Malformed Basic auth header in revoke request: {type(e).__name__}" + ) + # Fall back to form body try: await provider_service.revoke_token( @@ -613,8 +622,12 @@ async def introspect( try: decoded = base64.b64decode(auth_header[6:]).decode() client_id, client_secret = decoded.split(":", 1) - except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body - pass + except Exception as e: + # Log malformed Basic auth for security monitoring + logger.warning( + f"Malformed Basic auth header in introspect request: {type(e).__name__}" + ) + # Fall back to form body try: result = await provider_service.introspect_token( diff --git a/backend/app/models/oauth_authorization_code.py b/backend/app/models/oauth_authorization_code.py index 5f0543c..3741f02 100644 --- a/backend/app/models/oauth_authorization_code.py +++ b/backend/app/models/oauth_authorization_code.py @@ -1,6 +1,6 @@ """OAuth authorization code model for OAuth provider mode.""" -from datetime import datetime +from datetime import UTC, datetime from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID @@ -83,7 +83,13 @@ class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin): @property def is_expired(self) -> bool: """Check if the authorization code has expired.""" - return datetime.utcnow() > self.expires_at.replace(tzinfo=None) + # Use timezone-aware comparison (datetime.utcnow() is deprecated) + now = datetime.now(UTC) + expires_at = self.expires_at + # Handle both timezone-aware and naive datetimes from DB + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + return now > expires_at @property def is_valid(self) -> bool: diff --git a/backend/app/models/oauth_provider_token.py b/backend/app/models/oauth_provider_token.py index 2f99826..765d6d7 100644 --- a/backend/app/models/oauth_provider_token.py +++ b/backend/app/models/oauth_provider_token.py @@ -1,6 +1,6 @@ """OAuth provider token models for OAuth provider mode.""" -from datetime import datetime +from datetime import UTC, datetime from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID @@ -90,7 +90,13 @@ class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin): @property def is_expired(self) -> bool: """Check if the refresh token has expired.""" - return datetime.utcnow() > self.expires_at.replace(tzinfo=None) + # Use timezone-aware comparison (datetime.utcnow() is deprecated) + now = datetime.now(UTC) + expires_at = self.expires_at + # Handle both timezone-aware and naive datetimes from DB + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + return now > expires_at @property def is_valid(self) -> bool: diff --git a/backend/app/services/oauth_provider_service.py b/backend/app/services/oauth_provider_service.py index a198364..80ab249 100644 --- a/backend/app/services/oauth_provider_service.py +++ b/backend/app/services/oauth_provider_service.py @@ -349,22 +349,51 @@ async def exchange_authorization_code( InvalidGrantError: If code is invalid, expired, or already used InvalidClientError: If client validation fails """ - # Get and validate authorization code - result = await db.execute( - select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code) - ) - auth_code = result.scalar_one_or_none() + # Atomically mark code as used and fetch it (prevents race condition) + # RFC 6749 Section 4.1.2: Authorization codes MUST be single-use + from sqlalchemy import update - if not auth_code: - raise InvalidGrantError("Invalid authorization code") - - if auth_code.used: - # Code reuse is a security incident - revoke all tokens for this grant - logger.warning( - f"Authorization code reuse detected for client {auth_code.client_id}" + # First, atomically mark the code as used and get affected count + update_stmt = ( + update(OAuthAuthorizationCode) + .where( + and_( + OAuthAuthorizationCode.code == code, + OAuthAuthorizationCode.used == False, # noqa: E712 + ) ) - await revoke_tokens_for_user_client(db, auth_code.user_id, auth_code.client_id) - raise InvalidGrantError("Authorization code has already been used") + .values(used=True) + .returning(OAuthAuthorizationCode.id) + ) + result = await db.execute(update_stmt) + updated_id = result.scalar_one_or_none() + + if not updated_id: + # Either code doesn't exist or was already used + # Check if it exists to provide appropriate error + check_result = await db.execute( + select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code) + ) + existing_code = check_result.scalar_one_or_none() + + if existing_code and existing_code.used: + # Code reuse is a security incident - revoke all tokens for this grant + logger.warning( + f"Authorization code reuse detected for client {existing_code.client_id}" + ) + await revoke_tokens_for_user_client( + db, existing_code.user_id, existing_code.client_id + ) + raise InvalidGrantError("Authorization code has already been used") + else: + raise InvalidGrantError("Invalid authorization code") + + # Now fetch the full auth code record + result = await db.execute( + select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id) + ) + auth_code = result.scalar_one() + await db.commit() if auth_code.is_expired: raise InvalidGrantError("Authorization code has expired") @@ -375,13 +404,19 @@ async def exchange_authorization_code( if auth_code.redirect_uri != redirect_uri: raise InvalidGrantError("redirect_uri mismatch") - # Validate client - client = await validate_client( - db, - client_id, - client_secret, - require_secret=(client_secret is not None), - ) + # Validate client - ALWAYS require secret for confidential clients + client = await get_client(db, client_id) + if not client: + raise InvalidClientError("Unknown client_id") + + # Confidential clients MUST authenticate (RFC 6749 Section 3.2.1) + if client.client_type == "confidential": + if not client_secret: + raise InvalidClientError("Client secret required for confidential clients") + client = await validate_client(db, client_id, client_secret, require_secret=True) + elif client_secret: + # Public client provided secret - validate it if given + client = await validate_client(db, client_id, client_secret, require_secret=True) # Verify PKCE if auth_code.code_challenge: @@ -397,10 +432,6 @@ async def exchange_authorization_code( # Public clients without PKCE - this shouldn't happen if we validated on authorize raise InvalidGrantError("PKCE required for public clients") - # Mark code as used (single-use) - auth_code.used = True - await db.commit() - # Get user user_result = await db.execute(select(User).where(User.id == auth_code.user_id)) user = user_result.scalar_one_or_none() diff --git a/backend/app/services/oauth_service.py b/backend/app/services/oauth_service.py index 26464b3..d52d459 100644 --- a/backend/app/services/oauth_service.py +++ b/backend/app/services/oauth_service.py @@ -246,6 +246,15 @@ class OAuthService: if not state_record: raise AuthenticationError("Invalid or expired OAuth state") + # SECURITY: Validate redirect_uri matches the one from authorization request + # This prevents authorization code injection attacks (RFC 6749 Section 10.6) + if state_record.redirect_uri != redirect_uri: + logger.warning( + f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, " + f"got {redirect_uri}" + ) + raise AuthenticationError("Redirect URI mismatch") + # Extract provider from state record (str for type safety) provider: str = str(state_record.provider) @@ -272,6 +281,38 @@ class OAuthService: config["token_url"], **token_params, ) + + # SECURITY: Validate nonce in ID token for OpenID Connect (Google) + # This prevents token replay attacks (OpenID Connect Core 3.1.3.7) + if provider == "google" and state_record.nonce: + id_token = token.get("id_token") + if id_token: + import base64 + import json + + # Decode ID token payload (JWT format: header.payload.signature) + try: + payload_b64 = id_token.split(".")[1] + # Add padding if needed + padding = 4 - len(payload_b64) % 4 + if padding != 4: + payload_b64 += "=" * padding + payload_json = base64.urlsafe_b64decode(payload_b64) + payload = json.loads(payload_json) + + token_nonce = payload.get("nonce") + if token_nonce != state_record.nonce: + logger.warning( + f"OAuth nonce mismatch: expected {state_record.nonce}, " + f"got {token_nonce}" + ) + raise AuthenticationError("Invalid ID token nonce") + except (IndexError, ValueError, json.JSONDecodeError) as e: + logger.warning(f"Failed to decode ID token for nonce validation: {e}") + # Continue without nonce validation if ID token is malformed + # The token will still be validated when getting user info + except AuthenticationError: + raise except Exception as e: logger.error(f"OAuth token exchange failed: {e!s}") raise AuthenticationError("Failed to exchange authorization code") @@ -294,8 +335,9 @@ class OAuthService: # Process user info and create/link account provider_user_id = str(user_info.get("id") or user_info.get("sub")) # Email can be None if user didn't grant email permission + # SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication email_raw = user_info.get("email") - provider_email: str | None = str(email_raw) if email_raw else None + provider_email: str | None = str(email_raw).lower().strip() if email_raw else None if not provider_user_id: raise AuthenticationError("Provider did not return user ID") diff --git a/backend/tests/e2e/conftest.py b/backend/tests/e2e/conftest.py index b996372..b085f96 100644 --- a/backend/tests/e2e/conftest.py +++ b/backend/tests/e2e/conftest.py @@ -214,9 +214,6 @@ async def e2e_superuser(e2e_client): """ from uuid import uuid4 - from app.crud.user import user as user_crud - from app.schemas.users import UserCreate - email = f"admin-{uuid4().hex[:8]}@example.com" password = "SuperAdmin123!" diff --git a/backend/tests/e2e/test_admin_workflows.py b/backend/tests/e2e/test_admin_workflows.py index 2372afd..6cfe15e 100644 --- a/backend/tests/e2e/test_admin_workflows.py +++ b/backend/tests/e2e/test_admin_workflows.py @@ -21,7 +21,7 @@ pytestmark = [ ] -async def register_user(client, email: str, password: str = "SecurePassword123!"): +async def register_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107 """Helper to register a user.""" resp = await client.post( "/api/v1/auth/register", @@ -35,7 +35,7 @@ async def register_user(client, email: str, password: str = "SecurePassword123!" return resp.json() -async def login_user(client, email: str, password: str = "SecurePassword123!"): +async def login_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107 """Helper to login a user.""" resp = await client.post( "/api/v1/auth/login", diff --git a/backend/tests/e2e/test_organization_workflows.py b/backend/tests/e2e/test_organization_workflows.py index ba8faa0..3cd038d 100644 --- a/backend/tests/e2e/test_organization_workflows.py +++ b/backend/tests/e2e/test_organization_workflows.py @@ -22,7 +22,7 @@ pytestmark = [ ] -async def register_and_login(client, email: str, password: str = "SecurePassword123!"): +async def register_and_login(client, email: str, password: str = "SecurePassword123!"): # noqa: S107 """Helper to register a user and get tokens.""" # Register await client.post( diff --git a/backend/tests/services/test_oauth_service.py b/backend/tests/services/test_oauth_service.py index 084b1d6..7fad254 100644 --- a/backend/tests/services/test_oauth_service.py +++ b/backend/tests/services/test_oauth_service.py @@ -451,6 +451,7 @@ class TestHandleCallbackComplete: state="valid_state_login", provider="google", code_verifier="test_verifier", + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -533,6 +534,7 @@ class TestHandleCallbackComplete: state_data = OAuthStateCreate( state="valid_state_inactive", provider="google", + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -583,6 +585,7 @@ class TestHandleCallbackComplete: state="valid_state_linking", provider="github", user_id=async_test_user.id, # User is logged in + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -648,6 +651,7 @@ class TestHandleCallbackComplete: state="valid_state_bad_user", provider="google", user_id=uuid4(), # Non-existent user + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -707,6 +711,7 @@ class TestHandleCallbackComplete: state="valid_state_already_linked", provider="google", user_id=async_test_user.id, + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -769,6 +774,7 @@ class TestHandleCallbackComplete: state_data = OAuthStateCreate( state="valid_state_autolink", provider="google", + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -832,6 +838,7 @@ class TestHandleCallbackComplete: state_data = OAuthStateCreate( state="valid_state_new_user", provider="google", + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -904,6 +911,7 @@ class TestHandleCallbackComplete: state_data = OAuthStateCreate( state="valid_state_no_email", provider="github", + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -961,6 +969,7 @@ class TestHandleCallbackComplete: state_data = OAuthStateCreate( state="valid_state_token_fail", provider="google", + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -1004,6 +1013,7 @@ class TestHandleCallbackComplete: state_data = OAuthStateCreate( state="valid_state_userinfo_fail", provider="google", + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -1047,6 +1057,7 @@ class TestHandleCallbackComplete: state_data = OAuthStateCreate( state="valid_state_no_token", provider="google", + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) @@ -1090,6 +1101,7 @@ class TestHandleCallbackComplete: state_data = OAuthStateCreate( state="valid_state_no_user_id", provider="google", + redirect_uri="http://localhost:3000/callback", expires_at=datetime.now(UTC) + timedelta(minutes=10), ) await oauth_state.create_state(session, obj_in=state_data) diff --git a/frontend/messages/en.json b/frontend/messages/en.json index f7035bf..888bce2 100644 --- a/frontend/messages/en.json +++ b/frontend/messages/en.json @@ -153,6 +153,7 @@ "authFailed": "Authentication Failed", "providerError": "The authentication provider returned an error", "missingParams": "Missing authentication parameters", + "stateMismatch": "Invalid OAuth state. Please try again.", "unexpectedError": "An unexpected error occurred during authentication", "backToLogin": "Back to Login" } diff --git a/frontend/messages/it.json b/frontend/messages/it.json index 523c115..4c65efc 100644 --- a/frontend/messages/it.json +++ b/frontend/messages/it.json @@ -153,6 +153,7 @@ "authFailed": "Autenticazione Fallita", "providerError": "Il provider di autenticazione ha restituito un errore", "missingParams": "Parametri di autenticazione mancanti", + "stateMismatch": "Stato OAuth non valido. Riprova.", "unexpectedError": "Si è verificato un errore durante l'autenticazione", "backToLogin": "Torna al Login" } diff --git a/frontend/src/app/[locale]/(auth)/auth/callback/[provider]/page.tsx b/frontend/src/app/[locale]/(auth)/auth/callback/[provider]/page.tsx index 62f452a..860007b 100644 --- a/frontend/src/app/[locale]/(auth)/auth/callback/[provider]/page.tsx +++ b/frontend/src/app/[locale]/(auth)/auth/callback/[provider]/page.tsx @@ -53,6 +53,18 @@ export default function OAuthCallbackPage() { return; } + // SECURITY: Validate state parameter against stored value (CSRF protection) + // This prevents cross-site request forgery attacks + const storedState = sessionStorage.getItem('oauth_state'); + if (!storedState || storedState !== state) { + // Clean up stored state on mismatch + sessionStorage.removeItem('oauth_state'); + sessionStorage.removeItem('oauth_mode'); + sessionStorage.removeItem('oauth_provider'); + setError(t('stateMismatch') || 'Invalid OAuth state. Please try again.'); + return; + } + hasProcessed.current = true; // Process the OAuth callback diff --git a/frontend/src/lib/api/hooks/useOAuth.ts b/frontend/src/lib/api/hooks/useOAuth.ts index f5634e0..2069be0 100644 --- a/frontend/src/lib/api/hooks/useOAuth.ts +++ b/frontend/src/lib/api/hooks/useOAuth.ts @@ -56,6 +56,44 @@ export function useOAuthProviders() { // OAuth Flow Mutations // ============================================================================ +// Allowed OAuth provider domains for security validation +const ALLOWED_OAUTH_DOMAINS = [ + 'accounts.google.com', + 'github.com', + 'www.facebook.com', // For future Facebook support + 'login.microsoftonline.com', // For future Microsoft support +]; + +/** + * Validate OAuth authorization URL + * SECURITY: Prevents open redirect attacks by only allowing known OAuth provider domains + */ +function isValidOAuthUrl(url: string): boolean { + try { + const parsed = new URL(url); + // Only allow HTTPS for OAuth (security requirement) + if (parsed.protocol !== 'https:') { + return false; + } + // Check if domain is in allowlist + return ALLOWED_OAUTH_DOMAINS.includes(parsed.hostname); + } catch { + return false; + } +} + +/** + * Extract state parameter from OAuth authorization URL + */ +function extractStateFromUrl(url: string): string | null { + try { + const parsed = new URL(url); + return parsed.searchParams.get('state'); + } catch { + return null; + } +} + /** * Start OAuth login/registration flow * Redirects user to the OAuth provider @@ -77,12 +115,27 @@ export function useOAuthStart() { }); if (response.data) { - // Store mode in sessionStorage for callback handling - sessionStorage.setItem('oauth_mode', mode); - sessionStorage.setItem('oauth_provider', provider); - // Response is { [key: string]: unknown }, so cast authorization_url const authUrl = (response.data as { authorization_url: string }).authorization_url; + + // SECURITY: Validate the authorization URL before redirecting + // This prevents open redirect attacks if the backend is compromised + if (!isValidOAuthUrl(authUrl)) { + throw new Error('Invalid OAuth authorization URL'); + } + + // SECURITY: Extract and store the state parameter for CSRF validation + // The callback page will verify this matches the state in the response + const state = extractStateFromUrl(authUrl); + if (!state) { + throw new Error('Missing state parameter in authorization URL'); + } + + // Store mode, provider, and state in sessionStorage for callback handling + sessionStorage.setItem('oauth_mode', mode); + sessionStorage.setItem('oauth_provider', provider); + sessionStorage.setItem('oauth_state', state); + // Redirect to OAuth provider window.location.href = authUrl; } @@ -151,14 +204,16 @@ export function useOAuthCallback() { queryClient.invalidateQueries({ queryKey: ['user'] }); } - // Clean up session storage + // Clean up session storage (including state for security) sessionStorage.removeItem('oauth_mode'); sessionStorage.removeItem('oauth_provider'); + sessionStorage.removeItem('oauth_state'); }, onError: () => { // Clean up session storage on error too sessionStorage.removeItem('oauth_mode'); sessionStorage.removeItem('oauth_provider'); + sessionStorage.removeItem('oauth_state'); }, }); } @@ -199,12 +254,25 @@ export function useOAuthLink() { }); if (response.data) { - // Store mode in sessionStorage for callback handling - sessionStorage.setItem('oauth_mode', 'link'); - sessionStorage.setItem('oauth_provider', provider); - // Response is { [key: string]: unknown }, so cast authorization_url const authUrl = (response.data as { authorization_url: string }).authorization_url; + + // SECURITY: Validate the authorization URL before redirecting + if (!isValidOAuthUrl(authUrl)) { + throw new Error('Invalid OAuth authorization URL'); + } + + // SECURITY: Extract and store the state parameter for CSRF validation + const state = extractStateFromUrl(authUrl); + if (!state) { + throw new Error('Missing state parameter in authorization URL'); + } + + // Store mode, provider, and state in sessionStorage for callback handling + sessionStorage.setItem('oauth_mode', 'link'); + sessionStorage.setItem('oauth_provider', provider); + sessionStorage.setItem('oauth_state', state); + // Redirect to OAuth provider window.location.href = authUrl; }