Enhance OAuth security and state validation

- Implemented stricter OAuth security measures, including CSRF protection via state parameter validation and redirect_uri checks.
- Updated OAuth models to support timezone-aware datetime comparisons, replacing deprecated `utcnow`.
- Enhanced logging for malformed Basic auth headers during token, introspect, and revoke requests.
- Added allowlist validation for OAuth provider domains to prevent open redirect attacks.
- Improved nonce validation for OpenID Connect tokens, ensuring token integrity during Google provider flows.
- Updated E2E and unit tests to cover new security features and expanded OAuth state handling scenarios.
This commit is contained in:
Felipe Cardoso
2025-11-25 23:50:43 +01:00
parent 7716468d72
commit 400d6f6f75
14 changed files with 246 additions and 57 deletions

View File

@@ -42,7 +42,7 @@ Default superuser (change in production):
│ │ ├── schemas/ # Pydantic request/response schemas │ │ ├── schemas/ # Pydantic request/response schemas
│ │ ├── services/ # Business logic layer │ │ ├── services/ # Business logic layer
│ │ └── utils/ # Utilities (security, device detection) │ │ └── utils/ # Utilities (security, device detection)
│ ├── tests/ # 97% coverage, 743 tests │ ├── tests/ # 96% coverage, 987 tests
│ └── alembic/ # Database migrations │ └── alembic/ # Database migrations
└── frontend/ # Next.js 15 frontend └── frontend/ # Next.js 15 frontend
@@ -128,7 +128,7 @@ Permission dependencies in `api/dependencies/permissions.py`:
### Testing Infrastructure ### Testing Infrastructure
**Backend Unit/Integration (pytest + SQLite):** **Backend Unit/Integration (pytest + SQLite):**
- 97% coverage, 743+ tests - 96% coverage, 987 tests
- Security-focused: JWT attacks, session hijacking, privilege escalation - Security-focused: JWT attacks, session hijacking, privilege escalation
- Async fixtures in `tests/conftest.py` - Async fixtures in `tests/conftest.py`
- Run: `IS_TEST=True uv run pytest` or `make test` - Run: `IS_TEST=True uv run pytest` or `make test`
@@ -265,7 +265,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
- Organization system (multi-tenant with RBAC) - Organization system (multi-tenant with RBAC)
- Admin panel (user/org management, bulk operations) - Admin panel (user/org management, bulk operations)
- **Internationalization (i18n)** with English and Italian - **Internationalization (i18n)** with English and Italian
- Comprehensive test coverage (97% backend, 97% frontend unit, 56 E2E tests) - Comprehensive test coverage (96% backend, 97% frontend unit, 56 E2E tests)
- Design system documentation - Design system documentation
- **Marketing landing page** with animations - **Marketing landing page** with animations
- **`/dev` documentation portal** with live examples - **`/dev` documentation portal** with live examples

View File

@@ -169,11 +169,12 @@ async def authorize(
detail="invalid_request: response_type must be 'code'", detail="invalid_request: response_type must be 'code'",
) )
# Validate PKCE method if provided # Validate PKCE method if provided - ONLY S256 is allowed (RFC 7636 Section 4.3)
if code_challenge_method and code_challenge_method not in ["S256", "plain"]: # "plain" method provides no security benefit and MUST NOT be used
if code_challenge_method and code_challenge_method != "S256":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="invalid_request: code_challenge_method must be 'S256'", detail="invalid_request: code_challenge_method must be 'S256' (plain is not supported)",
) )
# Validate client # Validate client
@@ -441,8 +442,12 @@ async def token(
try: try:
decoded = base64.b64decode(auth_header[6:]).decode() decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1) client_id, client_secret = decoded.split(":", 1)
except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body except Exception as e:
pass # Log malformed Basic auth for security monitoring
logger.warning(
f"Malformed Basic auth header in token request: {type(e).__name__}"
)
# Fall back to form body
if not client_id: if not client_id:
raise HTTPException( raise HTTPException(
@@ -547,8 +552,12 @@ async def revoke(
try: try:
decoded = base64.b64decode(auth_header[6:]).decode() decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1) client_id, client_secret = decoded.split(":", 1)
except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body except Exception as e:
pass # Log malformed Basic auth for security monitoring
logger.warning(
f"Malformed Basic auth header in revoke request: {type(e).__name__}"
)
# Fall back to form body
try: try:
await provider_service.revoke_token( await provider_service.revoke_token(
@@ -613,8 +622,12 @@ async def introspect(
try: try:
decoded = base64.b64decode(auth_header[6:]).decode() decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1) client_id, client_secret = decoded.split(":", 1)
except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body except Exception as e:
pass # Log malformed Basic auth for security monitoring
logger.warning(
f"Malformed Basic auth header in introspect request: {type(e).__name__}"
)
# Fall back to form body
try: try:
result = await provider_service.introspect_token( result = await provider_service.introspect_token(

View File

@@ -1,6 +1,6 @@
"""OAuth authorization code model for OAuth provider mode.""" """OAuth authorization code model for OAuth provider mode."""
from datetime import datetime from datetime import UTC, datetime
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
@@ -83,7 +83,13 @@ class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
@property @property
def is_expired(self) -> bool: def is_expired(self) -> bool:
"""Check if the authorization code has expired.""" """Check if the authorization code has expired."""
return datetime.utcnow() > self.expires_at.replace(tzinfo=None) # Use timezone-aware comparison (datetime.utcnow() is deprecated)
now = datetime.now(UTC)
expires_at = self.expires_at
# Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at
@property @property
def is_valid(self) -> bool: def is_valid(self) -> bool:

View File

@@ -1,6 +1,6 @@
"""OAuth provider token models for OAuth provider mode.""" """OAuth provider token models for OAuth provider mode."""
from datetime import datetime from datetime import UTC, datetime
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
@@ -90,7 +90,13 @@ class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
@property @property
def is_expired(self) -> bool: def is_expired(self) -> bool:
"""Check if the refresh token has expired.""" """Check if the refresh token has expired."""
return datetime.utcnow() > self.expires_at.replace(tzinfo=None) # Use timezone-aware comparison (datetime.utcnow() is deprecated)
now = datetime.now(UTC)
expires_at = self.expires_at
# Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at
@property @property
def is_valid(self) -> bool: def is_valid(self) -> bool:

View File

@@ -349,22 +349,51 @@ async def exchange_authorization_code(
InvalidGrantError: If code is invalid, expired, or already used InvalidGrantError: If code is invalid, expired, or already used
InvalidClientError: If client validation fails InvalidClientError: If client validation fails
""" """
# Get and validate authorization code # Atomically mark code as used and fetch it (prevents race condition)
result = await db.execute( # RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
from sqlalchemy import update
# First, atomically mark the code as used and get affected count
update_stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
)
result = await db.execute(update_stmt)
updated_id = result.scalar_one_or_none()
if not updated_id:
# Either code doesn't exist or was already used
# Check if it exists to provide appropriate error
check_result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code) select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
) )
auth_code = result.scalar_one_or_none() existing_code = check_result.scalar_one_or_none()
if not auth_code: if existing_code and existing_code.used:
raise InvalidGrantError("Invalid authorization code")
if auth_code.used:
# Code reuse is a security incident - revoke all tokens for this grant # Code reuse is a security incident - revoke all tokens for this grant
logger.warning( logger.warning(
f"Authorization code reuse detected for client {auth_code.client_id}" f"Authorization code reuse detected for client {existing_code.client_id}"
)
await revoke_tokens_for_user_client(
db, existing_code.user_id, existing_code.client_id
) )
await revoke_tokens_for_user_client(db, auth_code.user_id, auth_code.client_id)
raise InvalidGrantError("Authorization code has already been used") raise InvalidGrantError("Authorization code has already been used")
else:
raise InvalidGrantError("Invalid authorization code")
# Now fetch the full auth code record
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
)
auth_code = result.scalar_one()
await db.commit()
if auth_code.is_expired: if auth_code.is_expired:
raise InvalidGrantError("Authorization code has expired") raise InvalidGrantError("Authorization code has expired")
@@ -375,13 +404,19 @@ async def exchange_authorization_code(
if auth_code.redirect_uri != redirect_uri: if auth_code.redirect_uri != redirect_uri:
raise InvalidGrantError("redirect_uri mismatch") raise InvalidGrantError("redirect_uri mismatch")
# Validate client # Validate client - ALWAYS require secret for confidential clients
client = await validate_client( client = await get_client(db, client_id)
db, if not client:
client_id, raise InvalidClientError("Unknown client_id")
client_secret,
require_secret=(client_secret is not None), # Confidential clients MUST authenticate (RFC 6749 Section 3.2.1)
) if client.client_type == "confidential":
if not client_secret:
raise InvalidClientError("Client secret required for confidential clients")
client = await validate_client(db, client_id, client_secret, require_secret=True)
elif client_secret:
# Public client provided secret - validate it if given
client = await validate_client(db, client_id, client_secret, require_secret=True)
# Verify PKCE # Verify PKCE
if auth_code.code_challenge: if auth_code.code_challenge:
@@ -397,10 +432,6 @@ async def exchange_authorization_code(
# Public clients without PKCE - this shouldn't happen if we validated on authorize # Public clients without PKCE - this shouldn't happen if we validated on authorize
raise InvalidGrantError("PKCE required for public clients") raise InvalidGrantError("PKCE required for public clients")
# Mark code as used (single-use)
auth_code.used = True
await db.commit()
# Get user # Get user
user_result = await db.execute(select(User).where(User.id == auth_code.user_id)) user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
user = user_result.scalar_one_or_none() user = user_result.scalar_one_or_none()

View File

@@ -246,6 +246,15 @@ class OAuthService:
if not state_record: if not state_record:
raise AuthenticationError("Invalid or expired OAuth state") raise AuthenticationError("Invalid or expired OAuth state")
# SECURITY: Validate redirect_uri matches the one from authorization request
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
if state_record.redirect_uri != redirect_uri:
logger.warning(
f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, "
f"got {redirect_uri}"
)
raise AuthenticationError("Redirect URI mismatch")
# Extract provider from state record (str for type safety) # Extract provider from state record (str for type safety)
provider: str = str(state_record.provider) provider: str = str(state_record.provider)
@@ -272,6 +281,38 @@ class OAuthService:
config["token_url"], config["token_url"],
**token_params, **token_params,
) )
# SECURITY: Validate nonce in ID token for OpenID Connect (Google)
# This prevents token replay attacks (OpenID Connect Core 3.1.3.7)
if provider == "google" and state_record.nonce:
id_token = token.get("id_token")
if id_token:
import base64
import json
# Decode ID token payload (JWT format: header.payload.signature)
try:
payload_b64 = id_token.split(".")[1]
# Add padding if needed
padding = 4 - len(payload_b64) % 4
if padding != 4:
payload_b64 += "=" * padding
payload_json = base64.urlsafe_b64decode(payload_b64)
payload = json.loads(payload_json)
token_nonce = payload.get("nonce")
if token_nonce != state_record.nonce:
logger.warning(
f"OAuth nonce mismatch: expected {state_record.nonce}, "
f"got {token_nonce}"
)
raise AuthenticationError("Invalid ID token nonce")
except (IndexError, ValueError, json.JSONDecodeError) as e:
logger.warning(f"Failed to decode ID token for nonce validation: {e}")
# Continue without nonce validation if ID token is malformed
# The token will still be validated when getting user info
except AuthenticationError:
raise
except Exception as e: except Exception as e:
logger.error(f"OAuth token exchange failed: {e!s}") logger.error(f"OAuth token exchange failed: {e!s}")
raise AuthenticationError("Failed to exchange authorization code") raise AuthenticationError("Failed to exchange authorization code")
@@ -294,8 +335,9 @@ class OAuthService:
# Process user info and create/link account # Process user info and create/link account
provider_user_id = str(user_info.get("id") or user_info.get("sub")) provider_user_id = str(user_info.get("id") or user_info.get("sub"))
# Email can be None if user didn't grant email permission # Email can be None if user didn't grant email permission
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
email_raw = user_info.get("email") email_raw = user_info.get("email")
provider_email: str | None = str(email_raw) if email_raw else None provider_email: str | None = str(email_raw).lower().strip() if email_raw else None
if not provider_user_id: if not provider_user_id:
raise AuthenticationError("Provider did not return user ID") raise AuthenticationError("Provider did not return user ID")

View File

@@ -214,9 +214,6 @@ async def e2e_superuser(e2e_client):
""" """
from uuid import uuid4 from uuid import uuid4
from app.crud.user import user as user_crud
from app.schemas.users import UserCreate
email = f"admin-{uuid4().hex[:8]}@example.com" email = f"admin-{uuid4().hex[:8]}@example.com"
password = "SuperAdmin123!" password = "SuperAdmin123!"

View File

@@ -21,7 +21,7 @@ pytestmark = [
] ]
async def register_user(client, email: str, password: str = "SecurePassword123!"): async def register_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
"""Helper to register a user.""" """Helper to register a user."""
resp = await client.post( resp = await client.post(
"/api/v1/auth/register", "/api/v1/auth/register",
@@ -35,7 +35,7 @@ async def register_user(client, email: str, password: str = "SecurePassword123!"
return resp.json() return resp.json()
async def login_user(client, email: str, password: str = "SecurePassword123!"): async def login_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
"""Helper to login a user.""" """Helper to login a user."""
resp = await client.post( resp = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",

View File

@@ -22,7 +22,7 @@ pytestmark = [
] ]
async def register_and_login(client, email: str, password: str = "SecurePassword123!"): async def register_and_login(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
"""Helper to register a user and get tokens.""" """Helper to register a user and get tokens."""
# Register # Register
await client.post( await client.post(

View File

@@ -451,6 +451,7 @@ class TestHandleCallbackComplete:
state="valid_state_login", state="valid_state_login",
provider="google", provider="google",
code_verifier="test_verifier", code_verifier="test_verifier",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -533,6 +534,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate( state_data = OAuthStateCreate(
state="valid_state_inactive", state="valid_state_inactive",
provider="google", provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -583,6 +585,7 @@ class TestHandleCallbackComplete:
state="valid_state_linking", state="valid_state_linking",
provider="github", provider="github",
user_id=async_test_user.id, # User is logged in user_id=async_test_user.id, # User is logged in
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -648,6 +651,7 @@ class TestHandleCallbackComplete:
state="valid_state_bad_user", state="valid_state_bad_user",
provider="google", provider="google",
user_id=uuid4(), # Non-existent user user_id=uuid4(), # Non-existent user
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -707,6 +711,7 @@ class TestHandleCallbackComplete:
state="valid_state_already_linked", state="valid_state_already_linked",
provider="google", provider="google",
user_id=async_test_user.id, user_id=async_test_user.id,
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -769,6 +774,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate( state_data = OAuthStateCreate(
state="valid_state_autolink", state="valid_state_autolink",
provider="google", provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -832,6 +838,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate( state_data = OAuthStateCreate(
state="valid_state_new_user", state="valid_state_new_user",
provider="google", provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -904,6 +911,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate( state_data = OAuthStateCreate(
state="valid_state_no_email", state="valid_state_no_email",
provider="github", provider="github",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -961,6 +969,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate( state_data = OAuthStateCreate(
state="valid_state_token_fail", state="valid_state_token_fail",
provider="google", provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -1004,6 +1013,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate( state_data = OAuthStateCreate(
state="valid_state_userinfo_fail", state="valid_state_userinfo_fail",
provider="google", provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -1047,6 +1057,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate( state_data = OAuthStateCreate(
state="valid_state_no_token", state="valid_state_no_token",
provider="google", provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)
@@ -1090,6 +1101,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate( state_data = OAuthStateCreate(
state="valid_state_no_user_id", state="valid_state_no_user_id",
provider="google", provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10), expires_at=datetime.now(UTC) + timedelta(minutes=10),
) )
await oauth_state.create_state(session, obj_in=state_data) await oauth_state.create_state(session, obj_in=state_data)

View File

@@ -153,6 +153,7 @@
"authFailed": "Authentication Failed", "authFailed": "Authentication Failed",
"providerError": "The authentication provider returned an error", "providerError": "The authentication provider returned an error",
"missingParams": "Missing authentication parameters", "missingParams": "Missing authentication parameters",
"stateMismatch": "Invalid OAuth state. Please try again.",
"unexpectedError": "An unexpected error occurred during authentication", "unexpectedError": "An unexpected error occurred during authentication",
"backToLogin": "Back to Login" "backToLogin": "Back to Login"
} }

View File

@@ -153,6 +153,7 @@
"authFailed": "Autenticazione Fallita", "authFailed": "Autenticazione Fallita",
"providerError": "Il provider di autenticazione ha restituito un errore", "providerError": "Il provider di autenticazione ha restituito un errore",
"missingParams": "Parametri di autenticazione mancanti", "missingParams": "Parametri di autenticazione mancanti",
"stateMismatch": "Stato OAuth non valido. Riprova.",
"unexpectedError": "Si è verificato un errore durante l'autenticazione", "unexpectedError": "Si è verificato un errore durante l'autenticazione",
"backToLogin": "Torna al Login" "backToLogin": "Torna al Login"
} }

View File

@@ -53,6 +53,18 @@ export default function OAuthCallbackPage() {
return; return;
} }
// SECURITY: Validate state parameter against stored value (CSRF protection)
// This prevents cross-site request forgery attacks
const storedState = sessionStorage.getItem('oauth_state');
if (!storedState || storedState !== state) {
// Clean up stored state on mismatch
sessionStorage.removeItem('oauth_state');
sessionStorage.removeItem('oauth_mode');
sessionStorage.removeItem('oauth_provider');
setError(t('stateMismatch') || 'Invalid OAuth state. Please try again.');
return;
}
hasProcessed.current = true; hasProcessed.current = true;
// Process the OAuth callback // Process the OAuth callback

View File

@@ -56,6 +56,44 @@ export function useOAuthProviders() {
// OAuth Flow Mutations // OAuth Flow Mutations
// ============================================================================ // ============================================================================
// Allowed OAuth provider domains for security validation
const ALLOWED_OAUTH_DOMAINS = [
'accounts.google.com',
'github.com',
'www.facebook.com', // For future Facebook support
'login.microsoftonline.com', // For future Microsoft support
];
/**
* Validate OAuth authorization URL
* SECURITY: Prevents open redirect attacks by only allowing known OAuth provider domains
*/
function isValidOAuthUrl(url: string): boolean {
try {
const parsed = new URL(url);
// Only allow HTTPS for OAuth (security requirement)
if (parsed.protocol !== 'https:') {
return false;
}
// Check if domain is in allowlist
return ALLOWED_OAUTH_DOMAINS.includes(parsed.hostname);
} catch {
return false;
}
}
/**
* Extract state parameter from OAuth authorization URL
*/
function extractStateFromUrl(url: string): string | null {
try {
const parsed = new URL(url);
return parsed.searchParams.get('state');
} catch {
return null;
}
}
/** /**
* Start OAuth login/registration flow * Start OAuth login/registration flow
* Redirects user to the OAuth provider * Redirects user to the OAuth provider
@@ -77,12 +115,27 @@ export function useOAuthStart() {
}); });
if (response.data) { if (response.data) {
// Store mode in sessionStorage for callback handling
sessionStorage.setItem('oauth_mode', mode);
sessionStorage.setItem('oauth_provider', provider);
// Response is { [key: string]: unknown }, so cast authorization_url // Response is { [key: string]: unknown }, so cast authorization_url
const authUrl = (response.data as { authorization_url: string }).authorization_url; const authUrl = (response.data as { authorization_url: string }).authorization_url;
// SECURITY: Validate the authorization URL before redirecting
// This prevents open redirect attacks if the backend is compromised
if (!isValidOAuthUrl(authUrl)) {
throw new Error('Invalid OAuth authorization URL');
}
// SECURITY: Extract and store the state parameter for CSRF validation
// The callback page will verify this matches the state in the response
const state = extractStateFromUrl(authUrl);
if (!state) {
throw new Error('Missing state parameter in authorization URL');
}
// Store mode, provider, and state in sessionStorage for callback handling
sessionStorage.setItem('oauth_mode', mode);
sessionStorage.setItem('oauth_provider', provider);
sessionStorage.setItem('oauth_state', state);
// Redirect to OAuth provider // Redirect to OAuth provider
window.location.href = authUrl; window.location.href = authUrl;
} }
@@ -151,14 +204,16 @@ export function useOAuthCallback() {
queryClient.invalidateQueries({ queryKey: ['user'] }); queryClient.invalidateQueries({ queryKey: ['user'] });
} }
// Clean up session storage // Clean up session storage (including state for security)
sessionStorage.removeItem('oauth_mode'); sessionStorage.removeItem('oauth_mode');
sessionStorage.removeItem('oauth_provider'); sessionStorage.removeItem('oauth_provider');
sessionStorage.removeItem('oauth_state');
}, },
onError: () => { onError: () => {
// Clean up session storage on error too // Clean up session storage on error too
sessionStorage.removeItem('oauth_mode'); sessionStorage.removeItem('oauth_mode');
sessionStorage.removeItem('oauth_provider'); sessionStorage.removeItem('oauth_provider');
sessionStorage.removeItem('oauth_state');
}, },
}); });
} }
@@ -199,12 +254,25 @@ export function useOAuthLink() {
}); });
if (response.data) { if (response.data) {
// Store mode in sessionStorage for callback handling
sessionStorage.setItem('oauth_mode', 'link');
sessionStorage.setItem('oauth_provider', provider);
// Response is { [key: string]: unknown }, so cast authorization_url // Response is { [key: string]: unknown }, so cast authorization_url
const authUrl = (response.data as { authorization_url: string }).authorization_url; const authUrl = (response.data as { authorization_url: string }).authorization_url;
// SECURITY: Validate the authorization URL before redirecting
if (!isValidOAuthUrl(authUrl)) {
throw new Error('Invalid OAuth authorization URL');
}
// SECURITY: Extract and store the state parameter for CSRF validation
const state = extractStateFromUrl(authUrl);
if (!state) {
throw new Error('Missing state parameter in authorization URL');
}
// Store mode, provider, and state in sessionStorage for callback handling
sessionStorage.setItem('oauth_mode', 'link');
sessionStorage.setItem('oauth_provider', provider);
sessionStorage.setItem('oauth_state', state);
// Redirect to OAuth provider // Redirect to OAuth provider
window.location.href = authUrl; window.location.href = authUrl;
} }