Reformatted multiline function calls, object definitions, and queries for improved code readability and consistency. Adjusted imports and constraints where necessary.
724 lines
25 KiB
Python
724 lines
25 KiB
Python
"""
|
|
OAuth Service for handling social authentication flows.
|
|
|
|
Supports:
|
|
- Google OAuth (OpenID Connect)
|
|
- GitHub OAuth
|
|
|
|
Features:
|
|
- PKCE support for public clients
|
|
- State parameter for CSRF protection
|
|
- Auto-linking by email (configurable)
|
|
- Account linking for existing users
|
|
"""
|
|
|
|
import logging
|
|
import secrets
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import TypedDict, cast
|
|
from uuid import UUID
|
|
|
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.auth import create_access_token, create_refresh_token
|
|
from app.core.config import settings
|
|
from app.core.exceptions import AuthenticationError
|
|
from app.crud import oauth_account, oauth_state
|
|
from app.models.user import User
|
|
from app.schemas.oauth import (
|
|
OAuthAccountCreate,
|
|
OAuthCallbackResponse,
|
|
OAuthProviderInfo,
|
|
OAuthProvidersResponse,
|
|
OAuthStateCreate,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OAuthProviderConfig(TypedDict, total=False):
|
|
"""Type definition for OAuth provider configuration."""
|
|
|
|
name: str
|
|
icon: str
|
|
authorize_url: str
|
|
token_url: str
|
|
userinfo_url: str
|
|
email_url: str # Optional, GitHub-only
|
|
scopes: list[str]
|
|
supports_pkce: bool
|
|
|
|
|
|
# Provider configurations
|
|
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
|
|
"google": {
|
|
"name": "Google",
|
|
"icon": "google",
|
|
"authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
|
|
"token_url": "https://oauth2.googleapis.com/token",
|
|
"userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
|
|
"scopes": ["openid", "email", "profile"],
|
|
"supports_pkce": True,
|
|
},
|
|
"github": {
|
|
"name": "GitHub",
|
|
"icon": "github",
|
|
"authorize_url": "https://github.com/login/oauth/authorize",
|
|
"token_url": "https://github.com/login/oauth/access_token",
|
|
"userinfo_url": "https://api.github.com/user",
|
|
"email_url": "https://api.github.com/user/emails",
|
|
"scopes": ["read:user", "user:email"],
|
|
"supports_pkce": False, # GitHub doesn't support PKCE
|
|
},
|
|
}
|
|
|
|
|
|
class OAuthService:
|
|
"""Service for handling OAuth authentication flows."""
|
|
|
|
@staticmethod
|
|
def get_enabled_providers() -> OAuthProvidersResponse:
|
|
"""
|
|
Get list of enabled OAuth providers.
|
|
|
|
Returns:
|
|
OAuthProvidersResponse with enabled providers
|
|
"""
|
|
providers = []
|
|
|
|
for provider_id in settings.enabled_oauth_providers:
|
|
if provider_id in OAUTH_PROVIDERS:
|
|
config = OAUTH_PROVIDERS[provider_id]
|
|
providers.append(
|
|
OAuthProviderInfo(
|
|
provider=provider_id,
|
|
name=config["name"],
|
|
icon=config["icon"],
|
|
)
|
|
)
|
|
|
|
return OAuthProvidersResponse(
|
|
enabled=settings.OAUTH_ENABLED and len(providers) > 0,
|
|
providers=providers,
|
|
)
|
|
|
|
@staticmethod
|
|
def _get_provider_credentials(provider: str) -> tuple[str, str]:
|
|
"""Get client ID and secret for a provider."""
|
|
if provider == "google":
|
|
client_id = settings.OAUTH_GOOGLE_CLIENT_ID
|
|
client_secret = settings.OAUTH_GOOGLE_CLIENT_SECRET
|
|
elif provider == "github":
|
|
client_id = settings.OAUTH_GITHUB_CLIENT_ID
|
|
client_secret = settings.OAUTH_GITHUB_CLIENT_SECRET
|
|
else:
|
|
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
|
|
|
if not client_id or not client_secret:
|
|
raise AuthenticationError(f"OAuth provider {provider} is not configured")
|
|
|
|
return client_id, client_secret
|
|
|
|
@staticmethod
|
|
async def create_authorization_url(
|
|
db: AsyncSession,
|
|
*,
|
|
provider: str,
|
|
redirect_uri: str,
|
|
user_id: str | None = None,
|
|
) -> tuple[str, str]:
|
|
"""
|
|
Create OAuth authorization URL with state and optional PKCE.
|
|
|
|
Args:
|
|
db: Database session
|
|
provider: OAuth provider (google, github)
|
|
redirect_uri: Callback URL after OAuth
|
|
user_id: User ID if linking account (user is logged in)
|
|
|
|
Returns:
|
|
Tuple of (authorization_url, state)
|
|
|
|
Raises:
|
|
AuthenticationError: If provider is not configured
|
|
"""
|
|
if not settings.OAUTH_ENABLED:
|
|
raise AuthenticationError("OAuth is not enabled")
|
|
|
|
if provider not in OAUTH_PROVIDERS:
|
|
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
|
|
|
if provider not in settings.enabled_oauth_providers:
|
|
raise AuthenticationError(f"OAuth provider {provider} is not enabled")
|
|
|
|
config = OAUTH_PROVIDERS[provider]
|
|
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
|
|
|
# Generate state for CSRF protection
|
|
state = secrets.token_urlsafe(32)
|
|
|
|
# Generate PKCE code verifier and challenge if supported
|
|
code_verifier = None
|
|
code_challenge = None
|
|
if config.get("supports_pkce"):
|
|
code_verifier = secrets.token_urlsafe(64)
|
|
# Create code_challenge using S256 method
|
|
import base64
|
|
import hashlib
|
|
|
|
code_challenge_bytes = hashlib.sha256(code_verifier.encode()).digest()
|
|
code_challenge = (
|
|
base64.urlsafe_b64encode(code_challenge_bytes).decode().rstrip("=")
|
|
)
|
|
|
|
# Generate nonce for OIDC (Google)
|
|
nonce = secrets.token_urlsafe(32) if provider == "google" else None
|
|
|
|
# Store state in database
|
|
from uuid import UUID
|
|
|
|
state_data = OAuthStateCreate(
|
|
state=state,
|
|
code_verifier=code_verifier,
|
|
nonce=nonce,
|
|
provider=provider,
|
|
redirect_uri=redirect_uri,
|
|
user_id=UUID(user_id) if user_id else None,
|
|
expires_at=datetime.now(UTC)
|
|
+ timedelta(minutes=settings.OAUTH_STATE_EXPIRE_MINUTES),
|
|
)
|
|
await oauth_state.create_state(db, obj_in=state_data)
|
|
|
|
# Build authorization URL
|
|
async with AsyncOAuth2Client(
|
|
client_id=client_id,
|
|
client_secret=client_secret,
|
|
redirect_uri=redirect_uri,
|
|
) as client:
|
|
# Prepare authorization params
|
|
auth_params = {
|
|
"state": state,
|
|
"scope": " ".join(config["scopes"]),
|
|
}
|
|
|
|
if code_challenge:
|
|
auth_params["code_challenge"] = code_challenge
|
|
auth_params["code_challenge_method"] = "S256"
|
|
|
|
if nonce:
|
|
auth_params["nonce"] = nonce
|
|
|
|
url, _ = client.create_authorization_url(
|
|
config["authorize_url"],
|
|
**auth_params,
|
|
)
|
|
|
|
logger.info(f"OAuth authorization URL created for {provider}")
|
|
return url, state
|
|
|
|
@staticmethod
|
|
async def handle_callback(
|
|
db: AsyncSession,
|
|
*,
|
|
code: str,
|
|
state: str,
|
|
redirect_uri: str,
|
|
) -> OAuthCallbackResponse:
|
|
"""
|
|
Handle OAuth callback and authenticate/create user.
|
|
|
|
Args:
|
|
db: Database session
|
|
code: Authorization code from provider
|
|
state: State parameter for CSRF verification
|
|
redirect_uri: Callback URL (must match authorization request)
|
|
|
|
Returns:
|
|
OAuthCallbackResponse with tokens
|
|
|
|
Raises:
|
|
AuthenticationError: If authentication fails
|
|
"""
|
|
# Validate and consume state
|
|
state_record = await oauth_state.get_and_consume_state(db, state=state)
|
|
if not state_record:
|
|
raise AuthenticationError("Invalid or expired OAuth state")
|
|
|
|
# SECURITY: Validate redirect_uri matches the one from authorization request
|
|
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
|
|
if state_record.redirect_uri != redirect_uri:
|
|
logger.warning(
|
|
f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, "
|
|
f"got {redirect_uri}"
|
|
)
|
|
raise AuthenticationError("Redirect URI mismatch")
|
|
|
|
# Extract provider from state record (str for type safety)
|
|
provider: str = str(state_record.provider)
|
|
|
|
if provider not in OAUTH_PROVIDERS:
|
|
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
|
|
|
config = OAUTH_PROVIDERS[provider]
|
|
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
|
|
|
# Exchange code for tokens
|
|
async with AsyncOAuth2Client(
|
|
client_id=client_id,
|
|
client_secret=client_secret,
|
|
redirect_uri=redirect_uri,
|
|
) as client:
|
|
try:
|
|
# Prepare token request params
|
|
token_params: dict[str, str] = {"code": code}
|
|
|
|
if state_record.code_verifier:
|
|
token_params["code_verifier"] = str(state_record.code_verifier)
|
|
|
|
token = await client.fetch_token(
|
|
config["token_url"],
|
|
**token_params,
|
|
)
|
|
|
|
# SECURITY: Validate ID token signature and nonce for OpenID Connect
|
|
# This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
|
|
if provider == "google" and state_record.nonce:
|
|
id_token = token.get("id_token")
|
|
if id_token:
|
|
await OAuthService._verify_google_id_token(
|
|
id_token=str(id_token),
|
|
expected_nonce=str(state_record.nonce),
|
|
client_id=client_id,
|
|
)
|
|
except AuthenticationError:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"OAuth token exchange failed: {e!s}")
|
|
raise AuthenticationError("Failed to exchange authorization code")
|
|
|
|
# Get user info from provider
|
|
try:
|
|
access_token = token.get("access_token")
|
|
if not access_token:
|
|
raise AuthenticationError("No access token received")
|
|
|
|
user_info = await OAuthService._get_user_info(
|
|
client, provider, config, access_token
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Failed to get user info: {e!s}")
|
|
raise AuthenticationError(
|
|
"Failed to get user information from provider"
|
|
)
|
|
|
|
# Process user info and create/link account
|
|
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
|
|
# Email can be None if user didn't grant email permission
|
|
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
|
|
email_raw = user_info.get("email")
|
|
provider_email: str | None = (
|
|
str(email_raw).lower().strip() if email_raw else None
|
|
)
|
|
|
|
if not provider_user_id:
|
|
raise AuthenticationError("Provider did not return user ID")
|
|
|
|
# Check if this OAuth account already exists
|
|
existing_oauth = await oauth_account.get_by_provider_id(
|
|
db, provider=provider, provider_user_id=provider_user_id
|
|
)
|
|
|
|
is_new_user = False
|
|
|
|
if existing_oauth:
|
|
# Existing OAuth account - login
|
|
user = existing_oauth.user
|
|
if not user.is_active:
|
|
raise AuthenticationError("User account is inactive")
|
|
|
|
# Update tokens if stored
|
|
if token.get("access_token"):
|
|
await oauth_account.update_tokens(
|
|
db,
|
|
account=existing_oauth,
|
|
access_token_encrypted=token.get("access_token"),
|
|
refresh_token_encrypted=token.get("refresh_token"),
|
|
token_expires_at=datetime.now(UTC)
|
|
+ timedelta(seconds=token.get("expires_in", 3600)),
|
|
)
|
|
|
|
logger.info(f"OAuth login successful for {user.email} via {provider}")
|
|
|
|
elif state_record.user_id:
|
|
# Account linking flow (user is already logged in)
|
|
result = await db.execute(
|
|
select(User).where(User.id == state_record.user_id)
|
|
)
|
|
user = result.scalar_one_or_none()
|
|
|
|
if not user:
|
|
raise AuthenticationError("User not found for account linking")
|
|
|
|
# Check if user already has this provider linked
|
|
user_id = cast(UUID, user.id)
|
|
existing_provider = await oauth_account.get_user_account_by_provider(
|
|
db, user_id=user_id, provider=provider
|
|
)
|
|
if existing_provider:
|
|
raise AuthenticationError(
|
|
f"You already have a {provider} account linked"
|
|
)
|
|
|
|
# Create OAuth account link
|
|
oauth_create = OAuthAccountCreate(
|
|
user_id=user_id,
|
|
provider=provider,
|
|
provider_user_id=provider_user_id,
|
|
provider_email=provider_email,
|
|
access_token_encrypted=token.get("access_token"),
|
|
refresh_token_encrypted=token.get("refresh_token"),
|
|
token_expires_at=datetime.now(UTC)
|
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
|
if token.get("expires_in")
|
|
else None,
|
|
)
|
|
await oauth_account.create_account(db, obj_in=oauth_create)
|
|
|
|
logger.info(f"OAuth account linked: {provider} -> {user.email}")
|
|
|
|
else:
|
|
# New OAuth login - check for existing user by email
|
|
user = None
|
|
|
|
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
|
|
result = await db.execute(
|
|
select(User).where(User.email == provider_email)
|
|
)
|
|
user = result.scalar_one_or_none()
|
|
|
|
if user:
|
|
# Auto-link to existing user
|
|
if not user.is_active:
|
|
raise AuthenticationError("User account is inactive")
|
|
|
|
# Check if user already has this provider linked
|
|
user_id = cast(UUID, user.id)
|
|
existing_provider = await oauth_account.get_user_account_by_provider(
|
|
db, user_id=user_id, provider=provider
|
|
)
|
|
if existing_provider:
|
|
# This shouldn't happen if we got here, but safety check
|
|
logger.warning(
|
|
f"OAuth account already linked (race condition?): {provider} -> {user.email}"
|
|
)
|
|
else:
|
|
# Create OAuth account link
|
|
oauth_create = OAuthAccountCreate(
|
|
user_id=user_id,
|
|
provider=provider,
|
|
provider_user_id=provider_user_id,
|
|
provider_email=provider_email,
|
|
access_token_encrypted=token.get("access_token"),
|
|
refresh_token_encrypted=token.get("refresh_token"),
|
|
token_expires_at=datetime.now(UTC)
|
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
|
if token.get("expires_in")
|
|
else None,
|
|
)
|
|
await oauth_account.create_account(db, obj_in=oauth_create)
|
|
|
|
logger.info(f"OAuth auto-linked by email: {provider} -> {user.email}")
|
|
|
|
else:
|
|
# Create new user
|
|
if not provider_email:
|
|
raise AuthenticationError(
|
|
f"Email is required for registration. "
|
|
f"Please grant email permission to {provider}."
|
|
)
|
|
|
|
user = await OAuthService._create_oauth_user(
|
|
db,
|
|
email=provider_email,
|
|
provider=provider,
|
|
provider_user_id=provider_user_id,
|
|
user_info=user_info,
|
|
token=token,
|
|
)
|
|
is_new_user = True
|
|
|
|
logger.info(f"New user created via OAuth: {user.email} ({provider})")
|
|
|
|
# Generate JWT tokens
|
|
claims = {
|
|
"is_superuser": user.is_superuser,
|
|
"email": user.email,
|
|
"first_name": user.first_name,
|
|
}
|
|
|
|
access_token_jwt = create_access_token(subject=str(user.id), claims=claims)
|
|
refresh_token_jwt = create_refresh_token(subject=str(user.id))
|
|
|
|
return OAuthCallbackResponse(
|
|
access_token=access_token_jwt,
|
|
refresh_token=refresh_token_jwt,
|
|
token_type="bearer",
|
|
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
|
is_new_user=is_new_user,
|
|
)
|
|
|
|
@staticmethod
|
|
async def _get_user_info(
|
|
client: AsyncOAuth2Client,
|
|
provider: str,
|
|
config: OAuthProviderConfig,
|
|
access_token: str,
|
|
) -> dict[str, object]:
|
|
"""Get user info from OAuth provider."""
|
|
headers = {"Authorization": f"Bearer {access_token}"}
|
|
|
|
if provider == "github":
|
|
# GitHub returns JSON with Accept header
|
|
headers["Accept"] = "application/vnd.github+json"
|
|
|
|
resp = await client.get(config["userinfo_url"], headers=headers)
|
|
resp.raise_for_status()
|
|
user_info = resp.json()
|
|
|
|
# GitHub requires separate request for email
|
|
if provider == "github" and not user_info.get("email"):
|
|
email_resp = await client.get(
|
|
config["email_url"],
|
|
headers=headers,
|
|
)
|
|
email_resp.raise_for_status()
|
|
emails = email_resp.json()
|
|
|
|
# Find primary verified email
|
|
for email_data in emails:
|
|
if email_data.get("primary") and email_data.get("verified"):
|
|
user_info["email"] = email_data["email"]
|
|
break
|
|
|
|
return user_info
|
|
|
|
# Google's OIDC configuration endpoints
|
|
GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
|
|
GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
|
|
|
|
@staticmethod
|
|
async def _verify_google_id_token(
|
|
id_token: str,
|
|
expected_nonce: str,
|
|
client_id: str,
|
|
) -> dict[str, object]:
|
|
"""
|
|
Verify Google ID token signature and claims.
|
|
|
|
SECURITY: This properly verifies the ID token by:
|
|
1. Fetching Google's public keys (JWKS)
|
|
2. Verifying the JWT signature against the public key
|
|
3. Validating issuer, audience, expiry, and nonce claims
|
|
|
|
Args:
|
|
id_token: The ID token JWT string
|
|
expected_nonce: The nonce we sent in the authorization request
|
|
client_id: Our OAuth client ID (expected audience)
|
|
|
|
Returns:
|
|
Decoded ID token payload
|
|
|
|
Raises:
|
|
AuthenticationError: If verification fails
|
|
"""
|
|
import httpx
|
|
from jose import jwt as jose_jwt
|
|
from jose.exceptions import JWTError
|
|
|
|
try:
|
|
# Fetch Google's public keys (JWKS)
|
|
# In production, consider caching this with TTL matching Cache-Control header
|
|
async with httpx.AsyncClient() as client:
|
|
jwks_response = await client.get(
|
|
OAuthService.GOOGLE_JWKS_URL,
|
|
timeout=10.0,
|
|
)
|
|
jwks_response.raise_for_status()
|
|
jwks = jwks_response.json()
|
|
|
|
# Get the key ID from the token header
|
|
unverified_header = jose_jwt.get_unverified_header(id_token)
|
|
kid = unverified_header.get("kid")
|
|
if not kid:
|
|
raise AuthenticationError("ID token missing key ID (kid)")
|
|
|
|
# Find the matching public key
|
|
public_key = None
|
|
for key in jwks.get("keys", []):
|
|
if key.get("kid") == kid:
|
|
public_key = key
|
|
break
|
|
|
|
if not public_key:
|
|
raise AuthenticationError("ID token signed with unknown key")
|
|
|
|
# Verify the token signature and decode claims
|
|
# jose library will verify signature against the JWK
|
|
payload = jose_jwt.decode(
|
|
id_token,
|
|
public_key,
|
|
algorithms=["RS256"], # Google uses RS256
|
|
audience=client_id,
|
|
issuer=OAuthService.GOOGLE_ISSUERS,
|
|
options={
|
|
"verify_signature": True,
|
|
"verify_aud": True,
|
|
"verify_iss": True,
|
|
"verify_exp": True,
|
|
"verify_iat": True,
|
|
},
|
|
)
|
|
|
|
# Verify nonce (OIDC replay attack protection)
|
|
token_nonce = payload.get("nonce")
|
|
if token_nonce != expected_nonce:
|
|
logger.warning(
|
|
f"OAuth ID token nonce mismatch: expected {expected_nonce}, "
|
|
f"got {token_nonce}"
|
|
)
|
|
raise AuthenticationError("Invalid ID token nonce")
|
|
|
|
logger.debug("Google ID token verified successfully")
|
|
return payload
|
|
|
|
except JWTError as e:
|
|
logger.warning(f"Google ID token verification failed: {e}")
|
|
raise AuthenticationError("Invalid ID token signature")
|
|
except httpx.HTTPError as e:
|
|
logger.error(f"Failed to fetch Google JWKS: {e}")
|
|
# If we can't verify the ID token, fail closed for security
|
|
raise AuthenticationError("Failed to verify ID token")
|
|
except Exception as e:
|
|
logger.error(f"Unexpected error verifying Google ID token: {e}")
|
|
raise AuthenticationError("ID token verification error")
|
|
|
|
@staticmethod
|
|
async def _create_oauth_user(
|
|
db: AsyncSession,
|
|
*,
|
|
email: str,
|
|
provider: str,
|
|
provider_user_id: str,
|
|
user_info: dict,
|
|
token: dict,
|
|
) -> User:
|
|
"""Create a new user from OAuth provider data."""
|
|
# Extract name from user_info
|
|
first_name = "User"
|
|
last_name = None
|
|
|
|
if provider == "google":
|
|
first_name = user_info.get("given_name") or user_info.get("name", "User")
|
|
last_name = user_info.get("family_name")
|
|
elif provider == "github":
|
|
# GitHub has full name, try to split
|
|
name = user_info.get("name") or user_info.get("login", "User")
|
|
parts = name.split(" ", 1)
|
|
first_name = parts[0]
|
|
last_name = parts[1] if len(parts) > 1 else None
|
|
|
|
# Create user (no password for OAuth-only users)
|
|
user = User(
|
|
email=email,
|
|
password_hash=None, # OAuth-only user
|
|
first_name=first_name,
|
|
last_name=last_name,
|
|
is_active=True,
|
|
is_superuser=False,
|
|
)
|
|
db.add(user)
|
|
await db.flush() # Get user.id
|
|
|
|
# Create OAuth account link
|
|
user_id = cast(UUID, user.id)
|
|
oauth_create = OAuthAccountCreate(
|
|
user_id=user_id,
|
|
provider=provider,
|
|
provider_user_id=provider_user_id,
|
|
provider_email=email,
|
|
access_token_encrypted=token.get("access_token"),
|
|
refresh_token_encrypted=token.get("refresh_token"),
|
|
token_expires_at=datetime.now(UTC)
|
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
|
if token.get("expires_in")
|
|
else None,
|
|
)
|
|
await oauth_account.create_account(db, obj_in=oauth_create)
|
|
|
|
await db.commit()
|
|
await db.refresh(user)
|
|
|
|
return user
|
|
|
|
@staticmethod
|
|
async def unlink_provider(
|
|
db: AsyncSession,
|
|
*,
|
|
user: User,
|
|
provider: str,
|
|
) -> bool:
|
|
"""
|
|
Unlink an OAuth provider from a user account.
|
|
|
|
Args:
|
|
db: Database session
|
|
user: User to unlink from
|
|
provider: Provider to unlink
|
|
|
|
Returns:
|
|
True if unlinked successfully
|
|
|
|
Raises:
|
|
AuthenticationError: If unlinking would leave user without login method
|
|
"""
|
|
# Check if user can safely remove this OAuth account
|
|
# Note: We query directly instead of using user.can_remove_oauth property
|
|
# because the property uses lazy loading which doesn't work in async context
|
|
user_id = cast(UUID, user.id)
|
|
has_password = user.password_hash is not None
|
|
oauth_accounts = await oauth_account.get_user_accounts(db, user_id=user_id)
|
|
can_remove = has_password or len(oauth_accounts) > 1
|
|
|
|
if not can_remove:
|
|
raise AuthenticationError(
|
|
"Cannot unlink OAuth account. You must have either a password set "
|
|
"or at least one other OAuth provider linked."
|
|
)
|
|
|
|
deleted = await oauth_account.delete_account(
|
|
db, user_id=user_id, provider=provider
|
|
)
|
|
|
|
if not deleted:
|
|
raise AuthenticationError(f"No {provider} account found to unlink")
|
|
|
|
logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
|
|
return True
|
|
|
|
@staticmethod
|
|
async def cleanup_expired_states(db: AsyncSession) -> int:
|
|
"""
|
|
Clean up expired OAuth states.
|
|
|
|
Should be called periodically (e.g., by a background task).
|
|
|
|
Args:
|
|
db: Database session
|
|
|
|
Returns:
|
|
Number of states cleaned up
|
|
"""
|
|
return await oauth_state.cleanup_expired(db)
|