Enhance OAuth security and state validation
- Implemented stricter OAuth security measures, including CSRF protection via state parameter validation and redirect_uri checks. - Updated OAuth models to support timezone-aware datetime comparisons, replacing deprecated `utcnow`. - Enhanced logging for malformed Basic auth headers during token, introspect, and revoke requests. - Added allowlist validation for OAuth provider domains to prevent open redirect attacks. - Improved nonce validation for OpenID Connect tokens, ensuring token integrity during Google provider flows. - Updated E2E and unit tests to cover new security features and expanded OAuth state handling scenarios.
This commit is contained in:
@@ -42,7 +42,7 @@ Default superuser (change in production):
|
|||||||
│ │ ├── schemas/ # Pydantic request/response schemas
|
│ │ ├── schemas/ # Pydantic request/response schemas
|
||||||
│ │ ├── services/ # Business logic layer
|
│ │ ├── services/ # Business logic layer
|
||||||
│ │ └── utils/ # Utilities (security, device detection)
|
│ │ └── utils/ # Utilities (security, device detection)
|
||||||
│ ├── tests/ # 97% coverage, 743 tests
|
│ ├── tests/ # 96% coverage, 987 tests
|
||||||
│ └── alembic/ # Database migrations
|
│ └── alembic/ # Database migrations
|
||||||
│
|
│
|
||||||
└── frontend/ # Next.js 15 frontend
|
└── frontend/ # Next.js 15 frontend
|
||||||
@@ -128,7 +128,7 @@ Permission dependencies in `api/dependencies/permissions.py`:
|
|||||||
### Testing Infrastructure
|
### Testing Infrastructure
|
||||||
|
|
||||||
**Backend Unit/Integration (pytest + SQLite):**
|
**Backend Unit/Integration (pytest + SQLite):**
|
||||||
- 97% coverage, 743+ tests
|
- 96% coverage, 987 tests
|
||||||
- Security-focused: JWT attacks, session hijacking, privilege escalation
|
- Security-focused: JWT attacks, session hijacking, privilege escalation
|
||||||
- Async fixtures in `tests/conftest.py`
|
- Async fixtures in `tests/conftest.py`
|
||||||
- Run: `IS_TEST=True uv run pytest` or `make test`
|
- Run: `IS_TEST=True uv run pytest` or `make test`
|
||||||
@@ -265,7 +265,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
|
|||||||
- Organization system (multi-tenant with RBAC)
|
- Organization system (multi-tenant with RBAC)
|
||||||
- Admin panel (user/org management, bulk operations)
|
- Admin panel (user/org management, bulk operations)
|
||||||
- **Internationalization (i18n)** with English and Italian
|
- **Internationalization (i18n)** with English and Italian
|
||||||
- Comprehensive test coverage (97% backend, 97% frontend unit, 56 E2E tests)
|
- Comprehensive test coverage (96% backend, 97% frontend unit, 56 E2E tests)
|
||||||
- Design system documentation
|
- Design system documentation
|
||||||
- **Marketing landing page** with animations
|
- **Marketing landing page** with animations
|
||||||
- **`/dev` documentation portal** with live examples
|
- **`/dev` documentation portal** with live examples
|
||||||
|
|||||||
@@ -169,11 +169,12 @@ async def authorize(
|
|||||||
detail="invalid_request: response_type must be 'code'",
|
detail="invalid_request: response_type must be 'code'",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate PKCE method if provided
|
# Validate PKCE method if provided - ONLY S256 is allowed (RFC 7636 Section 4.3)
|
||||||
if code_challenge_method and code_challenge_method not in ["S256", "plain"]:
|
# "plain" method provides no security benefit and MUST NOT be used
|
||||||
|
if code_challenge_method and code_challenge_method != "S256":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="invalid_request: code_challenge_method must be 'S256'",
|
detail="invalid_request: code_challenge_method must be 'S256' (plain is not supported)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate client
|
# Validate client
|
||||||
@@ -441,8 +442,12 @@ async def token(
|
|||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||||
client_id, client_secret = decoded.split(":", 1)
|
client_id, client_secret = decoded.split(":", 1)
|
||||||
except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body
|
except Exception as e:
|
||||||
pass
|
# Log malformed Basic auth for security monitoring
|
||||||
|
logger.warning(
|
||||||
|
f"Malformed Basic auth header in token request: {type(e).__name__}"
|
||||||
|
)
|
||||||
|
# Fall back to form body
|
||||||
|
|
||||||
if not client_id:
|
if not client_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -547,8 +552,12 @@ async def revoke(
|
|||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||||
client_id, client_secret = decoded.split(":", 1)
|
client_id, client_secret = decoded.split(":", 1)
|
||||||
except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body
|
except Exception as e:
|
||||||
pass
|
# Log malformed Basic auth for security monitoring
|
||||||
|
logger.warning(
|
||||||
|
f"Malformed Basic auth header in revoke request: {type(e).__name__}"
|
||||||
|
)
|
||||||
|
# Fall back to form body
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await provider_service.revoke_token(
|
await provider_service.revoke_token(
|
||||||
@@ -613,8 +622,12 @@ async def introspect(
|
|||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||||
client_id, client_secret = decoded.split(":", 1)
|
client_id, client_secret = decoded.split(":", 1)
|
||||||
except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body
|
except Exception as e:
|
||||||
pass
|
# Log malformed Basic auth for security monitoring
|
||||||
|
logger.warning(
|
||||||
|
f"Malformed Basic auth header in introspect request: {type(e).__name__}"
|
||||||
|
)
|
||||||
|
# Fall back to form body
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await provider_service.introspect_token(
|
result = await provider_service.introspect_token(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""OAuth authorization code model for OAuth provider mode."""
|
"""OAuth authorization code model for OAuth provider mode."""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
@@ -83,7 +83,13 @@ class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
|
|||||||
@property
|
@property
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
"""Check if the authorization code has expired."""
|
"""Check if the authorization code has expired."""
|
||||||
return datetime.utcnow() > self.expires_at.replace(tzinfo=None)
|
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expires_at = self.expires_at
|
||||||
|
# Handle both timezone-aware and naive datetimes from DB
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=UTC)
|
||||||
|
return now > expires_at
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_valid(self) -> bool:
|
def is_valid(self) -> bool:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""OAuth provider token models for OAuth provider mode."""
|
"""OAuth provider token models for OAuth provider mode."""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
@@ -90,7 +90,13 @@ class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
|
|||||||
@property
|
@property
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
"""Check if the refresh token has expired."""
|
"""Check if the refresh token has expired."""
|
||||||
return datetime.utcnow() > self.expires_at.replace(tzinfo=None)
|
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expires_at = self.expires_at
|
||||||
|
# Handle both timezone-aware and naive datetimes from DB
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=UTC)
|
||||||
|
return now > expires_at
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_valid(self) -> bool:
|
def is_valid(self) -> bool:
|
||||||
|
|||||||
@@ -349,22 +349,51 @@ async def exchange_authorization_code(
|
|||||||
InvalidGrantError: If code is invalid, expired, or already used
|
InvalidGrantError: If code is invalid, expired, or already used
|
||||||
InvalidClientError: If client validation fails
|
InvalidClientError: If client validation fails
|
||||||
"""
|
"""
|
||||||
# Get and validate authorization code
|
# Atomically mark code as used and fetch it (prevents race condition)
|
||||||
result = await db.execute(
|
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
# First, atomically mark the code as used and get affected count
|
||||||
|
update_stmt = (
|
||||||
|
update(OAuthAuthorizationCode)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthAuthorizationCode.code == code,
|
||||||
|
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(used=True)
|
||||||
|
.returning(OAuthAuthorizationCode.id)
|
||||||
|
)
|
||||||
|
result = await db.execute(update_stmt)
|
||||||
|
updated_id = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not updated_id:
|
||||||
|
# Either code doesn't exist or was already used
|
||||||
|
# Check if it exists to provide appropriate error
|
||||||
|
check_result = await db.execute(
|
||||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||||
)
|
)
|
||||||
auth_code = result.scalar_one_or_none()
|
existing_code = check_result.scalar_one_or_none()
|
||||||
|
|
||||||
if not auth_code:
|
if existing_code and existing_code.used:
|
||||||
raise InvalidGrantError("Invalid authorization code")
|
|
||||||
|
|
||||||
if auth_code.used:
|
|
||||||
# Code reuse is a security incident - revoke all tokens for this grant
|
# Code reuse is a security incident - revoke all tokens for this grant
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Authorization code reuse detected for client {auth_code.client_id}"
|
f"Authorization code reuse detected for client {existing_code.client_id}"
|
||||||
|
)
|
||||||
|
await revoke_tokens_for_user_client(
|
||||||
|
db, existing_code.user_id, existing_code.client_id
|
||||||
)
|
)
|
||||||
await revoke_tokens_for_user_client(db, auth_code.user_id, auth_code.client_id)
|
|
||||||
raise InvalidGrantError("Authorization code has already been used")
|
raise InvalidGrantError("Authorization code has already been used")
|
||||||
|
else:
|
||||||
|
raise InvalidGrantError("Invalid authorization code")
|
||||||
|
|
||||||
|
# Now fetch the full auth code record
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
|
||||||
|
)
|
||||||
|
auth_code = result.scalar_one()
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
if auth_code.is_expired:
|
if auth_code.is_expired:
|
||||||
raise InvalidGrantError("Authorization code has expired")
|
raise InvalidGrantError("Authorization code has expired")
|
||||||
@@ -375,13 +404,19 @@ async def exchange_authorization_code(
|
|||||||
if auth_code.redirect_uri != redirect_uri:
|
if auth_code.redirect_uri != redirect_uri:
|
||||||
raise InvalidGrantError("redirect_uri mismatch")
|
raise InvalidGrantError("redirect_uri mismatch")
|
||||||
|
|
||||||
# Validate client
|
# Validate client - ALWAYS require secret for confidential clients
|
||||||
client = await validate_client(
|
client = await get_client(db, client_id)
|
||||||
db,
|
if not client:
|
||||||
client_id,
|
raise InvalidClientError("Unknown client_id")
|
||||||
client_secret,
|
|
||||||
require_secret=(client_secret is not None),
|
# Confidential clients MUST authenticate (RFC 6749 Section 3.2.1)
|
||||||
)
|
if client.client_type == "confidential":
|
||||||
|
if not client_secret:
|
||||||
|
raise InvalidClientError("Client secret required for confidential clients")
|
||||||
|
client = await validate_client(db, client_id, client_secret, require_secret=True)
|
||||||
|
elif client_secret:
|
||||||
|
# Public client provided secret - validate it if given
|
||||||
|
client = await validate_client(db, client_id, client_secret, require_secret=True)
|
||||||
|
|
||||||
# Verify PKCE
|
# Verify PKCE
|
||||||
if auth_code.code_challenge:
|
if auth_code.code_challenge:
|
||||||
@@ -397,10 +432,6 @@ async def exchange_authorization_code(
|
|||||||
# Public clients without PKCE - this shouldn't happen if we validated on authorize
|
# Public clients without PKCE - this shouldn't happen if we validated on authorize
|
||||||
raise InvalidGrantError("PKCE required for public clients")
|
raise InvalidGrantError("PKCE required for public clients")
|
||||||
|
|
||||||
# Mark code as used (single-use)
|
|
||||||
auth_code.used = True
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
|
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
|
||||||
user = user_result.scalar_one_or_none()
|
user = user_result.scalar_one_or_none()
|
||||||
|
|||||||
@@ -246,6 +246,15 @@ class OAuthService:
|
|||||||
if not state_record:
|
if not state_record:
|
||||||
raise AuthenticationError("Invalid or expired OAuth state")
|
raise AuthenticationError("Invalid or expired OAuth state")
|
||||||
|
|
||||||
|
# SECURITY: Validate redirect_uri matches the one from authorization request
|
||||||
|
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
|
||||||
|
if state_record.redirect_uri != redirect_uri:
|
||||||
|
logger.warning(
|
||||||
|
f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, "
|
||||||
|
f"got {redirect_uri}"
|
||||||
|
)
|
||||||
|
raise AuthenticationError("Redirect URI mismatch")
|
||||||
|
|
||||||
# Extract provider from state record (str for type safety)
|
# Extract provider from state record (str for type safety)
|
||||||
provider: str = str(state_record.provider)
|
provider: str = str(state_record.provider)
|
||||||
|
|
||||||
@@ -272,6 +281,38 @@ class OAuthService:
|
|||||||
config["token_url"],
|
config["token_url"],
|
||||||
**token_params,
|
**token_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# SECURITY: Validate nonce in ID token for OpenID Connect (Google)
|
||||||
|
# This prevents token replay attacks (OpenID Connect Core 3.1.3.7)
|
||||||
|
if provider == "google" and state_record.nonce:
|
||||||
|
id_token = token.get("id_token")
|
||||||
|
if id_token:
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Decode ID token payload (JWT format: header.payload.signature)
|
||||||
|
try:
|
||||||
|
payload_b64 = id_token.split(".")[1]
|
||||||
|
# Add padding if needed
|
||||||
|
padding = 4 - len(payload_b64) % 4
|
||||||
|
if padding != 4:
|
||||||
|
payload_b64 += "=" * padding
|
||||||
|
payload_json = base64.urlsafe_b64decode(payload_b64)
|
||||||
|
payload = json.loads(payload_json)
|
||||||
|
|
||||||
|
token_nonce = payload.get("nonce")
|
||||||
|
if token_nonce != state_record.nonce:
|
||||||
|
logger.warning(
|
||||||
|
f"OAuth nonce mismatch: expected {state_record.nonce}, "
|
||||||
|
f"got {token_nonce}"
|
||||||
|
)
|
||||||
|
raise AuthenticationError("Invalid ID token nonce")
|
||||||
|
except (IndexError, ValueError, json.JSONDecodeError) as e:
|
||||||
|
logger.warning(f"Failed to decode ID token for nonce validation: {e}")
|
||||||
|
# Continue without nonce validation if ID token is malformed
|
||||||
|
# The token will still be validated when getting user info
|
||||||
|
except AuthenticationError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"OAuth token exchange failed: {e!s}")
|
logger.error(f"OAuth token exchange failed: {e!s}")
|
||||||
raise AuthenticationError("Failed to exchange authorization code")
|
raise AuthenticationError("Failed to exchange authorization code")
|
||||||
@@ -294,8 +335,9 @@ class OAuthService:
|
|||||||
# Process user info and create/link account
|
# Process user info and create/link account
|
||||||
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
|
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
|
||||||
# Email can be None if user didn't grant email permission
|
# Email can be None if user didn't grant email permission
|
||||||
|
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
|
||||||
email_raw = user_info.get("email")
|
email_raw = user_info.get("email")
|
||||||
provider_email: str | None = str(email_raw) if email_raw else None
|
provider_email: str | None = str(email_raw).lower().strip() if email_raw else None
|
||||||
|
|
||||||
if not provider_user_id:
|
if not provider_user_id:
|
||||||
raise AuthenticationError("Provider did not return user ID")
|
raise AuthenticationError("Provider did not return user ID")
|
||||||
|
|||||||
@@ -214,9 +214,6 @@ async def e2e_superuser(e2e_client):
|
|||||||
"""
|
"""
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
from app.schemas.users import UserCreate
|
|
||||||
|
|
||||||
email = f"admin-{uuid4().hex[:8]}@example.com"
|
email = f"admin-{uuid4().hex[:8]}@example.com"
|
||||||
password = "SuperAdmin123!"
|
password = "SuperAdmin123!"
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ pytestmark = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def register_user(client, email: str, password: str = "SecurePassword123!"):
|
async def register_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
|
||||||
"""Helper to register a user."""
|
"""Helper to register a user."""
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
"/api/v1/auth/register",
|
"/api/v1/auth/register",
|
||||||
@@ -35,7 +35,7 @@ async def register_user(client, email: str, password: str = "SecurePassword123!"
|
|||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
async def login_user(client, email: str, password: str = "SecurePassword123!"):
|
async def login_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
|
||||||
"""Helper to login a user."""
|
"""Helper to login a user."""
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
"/api/v1/auth/login",
|
"/api/v1/auth/login",
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ pytestmark = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def register_and_login(client, email: str, password: str = "SecurePassword123!"):
|
async def register_and_login(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
|
||||||
"""Helper to register a user and get tokens."""
|
"""Helper to register a user and get tokens."""
|
||||||
# Register
|
# Register
|
||||||
await client.post(
|
await client.post(
|
||||||
|
|||||||
@@ -451,6 +451,7 @@ class TestHandleCallbackComplete:
|
|||||||
state="valid_state_login",
|
state="valid_state_login",
|
||||||
provider="google",
|
provider="google",
|
||||||
code_verifier="test_verifier",
|
code_verifier="test_verifier",
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -533,6 +534,7 @@ class TestHandleCallbackComplete:
|
|||||||
state_data = OAuthStateCreate(
|
state_data = OAuthStateCreate(
|
||||||
state="valid_state_inactive",
|
state="valid_state_inactive",
|
||||||
provider="google",
|
provider="google",
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -583,6 +585,7 @@ class TestHandleCallbackComplete:
|
|||||||
state="valid_state_linking",
|
state="valid_state_linking",
|
||||||
provider="github",
|
provider="github",
|
||||||
user_id=async_test_user.id, # User is logged in
|
user_id=async_test_user.id, # User is logged in
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -648,6 +651,7 @@ class TestHandleCallbackComplete:
|
|||||||
state="valid_state_bad_user",
|
state="valid_state_bad_user",
|
||||||
provider="google",
|
provider="google",
|
||||||
user_id=uuid4(), # Non-existent user
|
user_id=uuid4(), # Non-existent user
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -707,6 +711,7 @@ class TestHandleCallbackComplete:
|
|||||||
state="valid_state_already_linked",
|
state="valid_state_already_linked",
|
||||||
provider="google",
|
provider="google",
|
||||||
user_id=async_test_user.id,
|
user_id=async_test_user.id,
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -769,6 +774,7 @@ class TestHandleCallbackComplete:
|
|||||||
state_data = OAuthStateCreate(
|
state_data = OAuthStateCreate(
|
||||||
state="valid_state_autolink",
|
state="valid_state_autolink",
|
||||||
provider="google",
|
provider="google",
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -832,6 +838,7 @@ class TestHandleCallbackComplete:
|
|||||||
state_data = OAuthStateCreate(
|
state_data = OAuthStateCreate(
|
||||||
state="valid_state_new_user",
|
state="valid_state_new_user",
|
||||||
provider="google",
|
provider="google",
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -904,6 +911,7 @@ class TestHandleCallbackComplete:
|
|||||||
state_data = OAuthStateCreate(
|
state_data = OAuthStateCreate(
|
||||||
state="valid_state_no_email",
|
state="valid_state_no_email",
|
||||||
provider="github",
|
provider="github",
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -961,6 +969,7 @@ class TestHandleCallbackComplete:
|
|||||||
state_data = OAuthStateCreate(
|
state_data = OAuthStateCreate(
|
||||||
state="valid_state_token_fail",
|
state="valid_state_token_fail",
|
||||||
provider="google",
|
provider="google",
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -1004,6 +1013,7 @@ class TestHandleCallbackComplete:
|
|||||||
state_data = OAuthStateCreate(
|
state_data = OAuthStateCreate(
|
||||||
state="valid_state_userinfo_fail",
|
state="valid_state_userinfo_fail",
|
||||||
provider="google",
|
provider="google",
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -1047,6 +1057,7 @@ class TestHandleCallbackComplete:
|
|||||||
state_data = OAuthStateCreate(
|
state_data = OAuthStateCreate(
|
||||||
state="valid_state_no_token",
|
state="valid_state_no_token",
|
||||||
provider="google",
|
provider="google",
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
@@ -1090,6 +1101,7 @@ class TestHandleCallbackComplete:
|
|||||||
state_data = OAuthStateCreate(
|
state_data = OAuthStateCreate(
|
||||||
state="valid_state_no_user_id",
|
state="valid_state_no_user_id",
|
||||||
provider="google",
|
provider="google",
|
||||||
|
redirect_uri="http://localhost:3000/callback",
|
||||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||||
)
|
)
|
||||||
await oauth_state.create_state(session, obj_in=state_data)
|
await oauth_state.create_state(session, obj_in=state_data)
|
||||||
|
|||||||
@@ -153,6 +153,7 @@
|
|||||||
"authFailed": "Authentication Failed",
|
"authFailed": "Authentication Failed",
|
||||||
"providerError": "The authentication provider returned an error",
|
"providerError": "The authentication provider returned an error",
|
||||||
"missingParams": "Missing authentication parameters",
|
"missingParams": "Missing authentication parameters",
|
||||||
|
"stateMismatch": "Invalid OAuth state. Please try again.",
|
||||||
"unexpectedError": "An unexpected error occurred during authentication",
|
"unexpectedError": "An unexpected error occurred during authentication",
|
||||||
"backToLogin": "Back to Login"
|
"backToLogin": "Back to Login"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -153,6 +153,7 @@
|
|||||||
"authFailed": "Autenticazione Fallita",
|
"authFailed": "Autenticazione Fallita",
|
||||||
"providerError": "Il provider di autenticazione ha restituito un errore",
|
"providerError": "Il provider di autenticazione ha restituito un errore",
|
||||||
"missingParams": "Parametri di autenticazione mancanti",
|
"missingParams": "Parametri di autenticazione mancanti",
|
||||||
|
"stateMismatch": "Stato OAuth non valido. Riprova.",
|
||||||
"unexpectedError": "Si è verificato un errore durante l'autenticazione",
|
"unexpectedError": "Si è verificato un errore durante l'autenticazione",
|
||||||
"backToLogin": "Torna al Login"
|
"backToLogin": "Torna al Login"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,6 +53,18 @@ export default function OAuthCallbackPage() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SECURITY: Validate state parameter against stored value (CSRF protection)
|
||||||
|
// This prevents cross-site request forgery attacks
|
||||||
|
const storedState = sessionStorage.getItem('oauth_state');
|
||||||
|
if (!storedState || storedState !== state) {
|
||||||
|
// Clean up stored state on mismatch
|
||||||
|
sessionStorage.removeItem('oauth_state');
|
||||||
|
sessionStorage.removeItem('oauth_mode');
|
||||||
|
sessionStorage.removeItem('oauth_provider');
|
||||||
|
setError(t('stateMismatch') || 'Invalid OAuth state. Please try again.');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
hasProcessed.current = true;
|
hasProcessed.current = true;
|
||||||
|
|
||||||
// Process the OAuth callback
|
// Process the OAuth callback
|
||||||
|
|||||||
@@ -56,6 +56,44 @@ export function useOAuthProviders() {
|
|||||||
// OAuth Flow Mutations
|
// OAuth Flow Mutations
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
|
// Allowed OAuth provider domains for security validation
|
||||||
|
const ALLOWED_OAUTH_DOMAINS = [
|
||||||
|
'accounts.google.com',
|
||||||
|
'github.com',
|
||||||
|
'www.facebook.com', // For future Facebook support
|
||||||
|
'login.microsoftonline.com', // For future Microsoft support
|
||||||
|
];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate OAuth authorization URL
|
||||||
|
* SECURITY: Prevents open redirect attacks by only allowing known OAuth provider domains
|
||||||
|
*/
|
||||||
|
function isValidOAuthUrl(url: string): boolean {
|
||||||
|
try {
|
||||||
|
const parsed = new URL(url);
|
||||||
|
// Only allow HTTPS for OAuth (security requirement)
|
||||||
|
if (parsed.protocol !== 'https:') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Check if domain is in allowlist
|
||||||
|
return ALLOWED_OAUTH_DOMAINS.includes(parsed.hostname);
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract state parameter from OAuth authorization URL
|
||||||
|
*/
|
||||||
|
function extractStateFromUrl(url: string): string | null {
|
||||||
|
try {
|
||||||
|
const parsed = new URL(url);
|
||||||
|
return parsed.searchParams.get('state');
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start OAuth login/registration flow
|
* Start OAuth login/registration flow
|
||||||
* Redirects user to the OAuth provider
|
* Redirects user to the OAuth provider
|
||||||
@@ -77,12 +115,27 @@ export function useOAuthStart() {
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (response.data) {
|
if (response.data) {
|
||||||
// Store mode in sessionStorage for callback handling
|
|
||||||
sessionStorage.setItem('oauth_mode', mode);
|
|
||||||
sessionStorage.setItem('oauth_provider', provider);
|
|
||||||
|
|
||||||
// Response is { [key: string]: unknown }, so cast authorization_url
|
// Response is { [key: string]: unknown }, so cast authorization_url
|
||||||
const authUrl = (response.data as { authorization_url: string }).authorization_url;
|
const authUrl = (response.data as { authorization_url: string }).authorization_url;
|
||||||
|
|
||||||
|
// SECURITY: Validate the authorization URL before redirecting
|
||||||
|
// This prevents open redirect attacks if the backend is compromised
|
||||||
|
if (!isValidOAuthUrl(authUrl)) {
|
||||||
|
throw new Error('Invalid OAuth authorization URL');
|
||||||
|
}
|
||||||
|
|
||||||
|
// SECURITY: Extract and store the state parameter for CSRF validation
|
||||||
|
// The callback page will verify this matches the state in the response
|
||||||
|
const state = extractStateFromUrl(authUrl);
|
||||||
|
if (!state) {
|
||||||
|
throw new Error('Missing state parameter in authorization URL');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store mode, provider, and state in sessionStorage for callback handling
|
||||||
|
sessionStorage.setItem('oauth_mode', mode);
|
||||||
|
sessionStorage.setItem('oauth_provider', provider);
|
||||||
|
sessionStorage.setItem('oauth_state', state);
|
||||||
|
|
||||||
// Redirect to OAuth provider
|
// Redirect to OAuth provider
|
||||||
window.location.href = authUrl;
|
window.location.href = authUrl;
|
||||||
}
|
}
|
||||||
@@ -151,14 +204,16 @@ export function useOAuthCallback() {
|
|||||||
queryClient.invalidateQueries({ queryKey: ['user'] });
|
queryClient.invalidateQueries({ queryKey: ['user'] });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up session storage
|
// Clean up session storage (including state for security)
|
||||||
sessionStorage.removeItem('oauth_mode');
|
sessionStorage.removeItem('oauth_mode');
|
||||||
sessionStorage.removeItem('oauth_provider');
|
sessionStorage.removeItem('oauth_provider');
|
||||||
|
sessionStorage.removeItem('oauth_state');
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: () => {
|
||||||
// Clean up session storage on error too
|
// Clean up session storage on error too
|
||||||
sessionStorage.removeItem('oauth_mode');
|
sessionStorage.removeItem('oauth_mode');
|
||||||
sessionStorage.removeItem('oauth_provider');
|
sessionStorage.removeItem('oauth_provider');
|
||||||
|
sessionStorage.removeItem('oauth_state');
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -199,12 +254,25 @@ export function useOAuthLink() {
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (response.data) {
|
if (response.data) {
|
||||||
// Store mode in sessionStorage for callback handling
|
|
||||||
sessionStorage.setItem('oauth_mode', 'link');
|
|
||||||
sessionStorage.setItem('oauth_provider', provider);
|
|
||||||
|
|
||||||
// Response is { [key: string]: unknown }, so cast authorization_url
|
// Response is { [key: string]: unknown }, so cast authorization_url
|
||||||
const authUrl = (response.data as { authorization_url: string }).authorization_url;
|
const authUrl = (response.data as { authorization_url: string }).authorization_url;
|
||||||
|
|
||||||
|
// SECURITY: Validate the authorization URL before redirecting
|
||||||
|
if (!isValidOAuthUrl(authUrl)) {
|
||||||
|
throw new Error('Invalid OAuth authorization URL');
|
||||||
|
}
|
||||||
|
|
||||||
|
// SECURITY: Extract and store the state parameter for CSRF validation
|
||||||
|
const state = extractStateFromUrl(authUrl);
|
||||||
|
if (!state) {
|
||||||
|
throw new Error('Missing state parameter in authorization URL');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store mode, provider, and state in sessionStorage for callback handling
|
||||||
|
sessionStorage.setItem('oauth_mode', 'link');
|
||||||
|
sessionStorage.setItem('oauth_provider', provider);
|
||||||
|
sessionStorage.setItem('oauth_state', state);
|
||||||
|
|
||||||
// Redirect to OAuth provider
|
// Redirect to OAuth provider
|
||||||
window.location.href = authUrl;
|
window.location.href = authUrl;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user