Files
fast-next-template/backend/app/services/oauth_service.py
Felipe Cardoso 29074f26a6 Remove outdated documentation files
- Deleted `I18N_IMPLEMENTATION_PLAN.md` and `PROJECT_PROGRESS.md` to declutter the repository.
- These documents were finalized, no longer relevant, and superseded by implemented features and external references.
2025-11-27 18:55:29 +01:00

718 lines
25 KiB
Python

"""
OAuth Service for handling social authentication flows.
Supports:
- Google OAuth (OpenID Connect)
- GitHub OAuth
Features:
- PKCE support for public clients
- State parameter for CSRF protection
- Auto-linking by email (configurable)
- Account linking for existing users
"""
import logging
import secrets
from datetime import UTC, datetime, timedelta
from typing import TypedDict, cast
from uuid import UUID
from authlib.integrations.httpx_client import AsyncOAuth2Client
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import create_access_token, create_refresh_token
from app.core.config import settings
from app.core.exceptions import AuthenticationError
from app.crud import oauth_account, oauth_state
from app.models.user import User
from app.schemas.oauth import (
OAuthAccountCreate,
OAuthCallbackResponse,
OAuthProviderInfo,
OAuthProvidersResponse,
OAuthStateCreate,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class OAuthProviderConfig(TypedDict, total=False):
    """Shape of one entry in the OAuth provider table.

    The TypedDict is declared with ``total=False``, so a type checker treats
    every key as optional.
    """

    # Display name of the provider.
    name: str
    # Icon identifier for the provider.
    icon: str
    # Endpoint the user is sent to in order to authorize.
    authorize_url: str
    # Endpoint where the authorization code is exchanged for tokens.
    token_url: str
    # Endpoint returning the authenticated user's profile.
    userinfo_url: str
    # Optional, GitHub-only: endpoint listing the user's email addresses.
    email_url: str
    # Scopes requested during authorization.
    scopes: list[str]
    # Whether a PKCE challenge is generated for this provider.
    supports_pkce: bool
# Provider configurations
#
# Static table of the OAuth providers this module knows about, keyed by the
# provider id used throughout the service ("google", "github"). A provider
# listed here is only offered to users when it also appears in
# ``settings.enabled_oauth_providers``.
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
    # Google: gets a PKCE challenge, and the service additionally sends an
    # OIDC nonce for it.
    "google": {
        "name": "Google",
        "icon": "google",
        "authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
        "token_url": "https://oauth2.googleapis.com/token",
        "userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
        "scopes": ["openid", "email", "profile"],
        "supports_pkce": True,
    },
    # GitHub: the only entry with "email_url", which the service queries
    # when the profile response carries no email.
    "github": {
        "name": "GitHub",
        "icon": "github",
        "authorize_url": "https://github.com/login/oauth/authorize",
        "token_url": "https://github.com/login/oauth/access_token",
        "userinfo_url": "https://api.github.com/user",
        "email_url": "https://api.github.com/user/emails",
        "scopes": ["read:user", "user:email"],
        # NOTE(review): re-check this against current GitHub documentation.
        "supports_pkce": False,  # GitHub doesn't support PKCE
    },
}
class OAuthService:
    """Service for handling OAuth authentication flows.

    Every operation is a static method, so the class serves as a namespace:
    building authorization URLs, handling provider callbacks (login, account
    linking, auto-linking by email, registration), unlinking providers and
    cleaning up expired state records.
    """

    # Google's OIDC configuration endpoints
    GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
    GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")

    @staticmethod
    def get_enabled_providers() -> OAuthProvidersResponse:
        """
        Get list of enabled OAuth providers.

        Providers named in ``settings.enabled_oauth_providers`` that are not
        present in ``OAUTH_PROVIDERS`` are skipped.

        Returns:
            OAuthProvidersResponse with enabled providers
        """
        providers = [
            OAuthProviderInfo(
                provider=provider_id,
                name=OAUTH_PROVIDERS[provider_id]["name"],
                icon=OAUTH_PROVIDERS[provider_id]["icon"],
            )
            for provider_id in settings.enabled_oauth_providers
            if provider_id in OAUTH_PROVIDERS
        ]
        return OAuthProvidersResponse(
            enabled=settings.OAUTH_ENABLED and len(providers) > 0,
            providers=providers,
        )

    @staticmethod
    def _get_provider_credentials(provider: str) -> tuple[str, str]:
        """
        Get client ID and secret for a provider.

        Args:
            provider: OAuth provider (google, github)

        Returns:
            Tuple of (client_id, client_secret)

        Raises:
            AuthenticationError: If the provider is unknown or either
                credential is missing from settings
        """
        if provider == "google":
            client_id = settings.OAUTH_GOOGLE_CLIENT_ID
            client_secret = settings.OAUTH_GOOGLE_CLIENT_SECRET
        elif provider == "github":
            client_id = settings.OAUTH_GITHUB_CLIENT_ID
            client_secret = settings.OAUTH_GITHUB_CLIENT_SECRET
        else:
            raise AuthenticationError(f"Unknown OAuth provider: {provider}")
        if not client_id or not client_secret:
            raise AuthenticationError(f"OAuth provider {provider} is not configured")
        return client_id, client_secret

    @staticmethod
    def _generate_pkce_pair() -> tuple[str, str]:
        """Create a PKCE (code_verifier, code_challenge) pair using S256."""
        import base64
        import hashlib

        code_verifier = secrets.token_urlsafe(64)
        digest = hashlib.sha256(code_verifier.encode()).digest()
        # S256 challenge: base64url-encoded SHA-256 digest without padding
        code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
        return code_verifier, code_challenge

    @staticmethod
    def _token_expires_at(token: dict) -> datetime | None:
        """Return the expiry implied by ``expires_in``, or None if absent."""
        expires_in = token.get("expires_in")
        if not expires_in:
            return None
        return datetime.now(UTC) + timedelta(seconds=expires_in)

    @staticmethod
    async def _link_oauth_account(
        db: AsyncSession,
        *,
        user_id: UUID,
        provider: str,
        provider_user_id: str,
        provider_email: str | None,
        token: dict,
    ) -> None:
        """Persist an OAuth account record linking a provider identity to a user."""
        # NOTE(review): the provider tokens are passed through unchanged into
        # fields named "*_encrypted"; confirm encryption happens downstream.
        oauth_create = OAuthAccountCreate(
            user_id=user_id,
            provider=provider,
            provider_user_id=provider_user_id,
            provider_email=provider_email,
            access_token_encrypted=token.get("access_token"),
            refresh_token_encrypted=token.get("refresh_token"),
            token_expires_at=OAuthService._token_expires_at(token),
        )
        await oauth_account.create_account(db, obj_in=oauth_create)

    @staticmethod
    async def create_authorization_url(
        db: AsyncSession,
        *,
        provider: str,
        redirect_uri: str,
        user_id: str | None = None,
    ) -> tuple[str, str]:
        """
        Create OAuth authorization URL with state and optional PKCE.

        Args:
            db: Database session
            provider: OAuth provider (google, github)
            redirect_uri: Callback URL after OAuth
            user_id: User ID if linking account (user is logged in)

        Returns:
            Tuple of (authorization_url, state)

        Raises:
            AuthenticationError: If OAuth is disabled, or the provider is
                unknown, not enabled or not configured
            ValueError: If user_id is given but is not a valid UUID string
        """
        if not settings.OAUTH_ENABLED:
            raise AuthenticationError("OAuth is not enabled")
        if provider not in OAUTH_PROVIDERS:
            raise AuthenticationError(f"Unknown OAuth provider: {provider}")
        if provider not in settings.enabled_oauth_providers:
            raise AuthenticationError(f"OAuth provider {provider} is not enabled")

        config = OAUTH_PROVIDERS[provider]
        client_id, client_secret = OAuthService._get_provider_credentials(provider)

        # Generate state for CSRF protection
        state = secrets.token_urlsafe(32)

        # Generate PKCE code verifier and challenge if supported
        code_verifier = None
        code_challenge = None
        if config.get("supports_pkce"):
            code_verifier, code_challenge = OAuthService._generate_pkce_pair()

        # Generate nonce for OIDC (Google)
        nonce = secrets.token_urlsafe(32) if provider == "google" else None

        # Store state in database so the callback can validate it and recover
        # the verifier, nonce, provider and (when linking) the user
        state_data = OAuthStateCreate(
            state=state,
            code_verifier=code_verifier,
            nonce=nonce,
            provider=provider,
            redirect_uri=redirect_uri,
            user_id=UUID(user_id) if user_id else None,
            expires_at=datetime.now(UTC)
            + timedelta(minutes=settings.OAUTH_STATE_EXPIRE_MINUTES),
        )
        await oauth_state.create_state(db, obj_in=state_data)

        # Build authorization URL
        auth_params = {
            "state": state,
            "scope": " ".join(config["scopes"]),
        }
        if code_challenge:
            auth_params["code_challenge"] = code_challenge
            auth_params["code_challenge_method"] = "S256"
        if nonce:
            auth_params["nonce"] = nonce

        async with AsyncOAuth2Client(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=redirect_uri,
        ) as client:
            url, _ = client.create_authorization_url(
                config["authorize_url"],
                **auth_params,
            )

        logger.info(f"OAuth authorization URL created for {provider}")
        return url, state

    @staticmethod
    async def handle_callback(
        db: AsyncSession,
        *,
        code: str,
        state: str,
        redirect_uri: str,
    ) -> OAuthCallbackResponse:
        """
        Handle OAuth callback and authenticate/create user.

        After the code exchange, one of four paths is taken: login through an
        already-linked OAuth account, linking to the logged-in user recorded
        in the state, auto-linking to an existing user by email (when
        enabled), or creating a new user.

        Args:
            db: Database session
            code: Authorization code from provider
            state: State parameter for CSRF verification
            redirect_uri: Callback URL (must match authorization request)

        Returns:
            OAuthCallbackResponse with tokens

        Raises:
            AuthenticationError: If authentication fails
        """
        # Validate and consume state
        state_record = await oauth_state.get_and_consume_state(db, state=state)
        if not state_record:
            raise AuthenticationError("Invalid or expired OAuth state")

        # SECURITY: Validate redirect_uri matches the one from authorization request
        # This prevents authorization code injection attacks (RFC 6749 Section 10.6)
        if state_record.redirect_uri != redirect_uri:
            logger.warning(
                f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, "
                f"got {redirect_uri}"
            )
            raise AuthenticationError("Redirect URI mismatch")

        # Extract provider from state record (str for type safety)
        provider: str = str(state_record.provider)
        if provider not in OAUTH_PROVIDERS:
            raise AuthenticationError(f"Unknown OAuth provider: {provider}")
        config = OAUTH_PROVIDERS[provider]
        client_id, client_secret = OAuthService._get_provider_credentials(provider)

        async with AsyncOAuth2Client(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=redirect_uri,
        ) as client:
            # Exchange code for tokens
            try:
                token_params: dict[str, str] = {"code": code}
                if state_record.code_verifier:
                    token_params["code_verifier"] = str(state_record.code_verifier)
                token = await client.fetch_token(
                    config["token_url"],
                    **token_params,
                )

                # SECURITY: Validate ID token signature and nonce for OpenID Connect
                # This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
                # Verification is skipped when the response has no id_token.
                if provider == "google" and state_record.nonce:
                    id_token = token.get("id_token")
                    if id_token:
                        issued_access_token = token.get("access_token")
                        await OAuthService._verify_google_id_token(
                            id_token=str(id_token),
                            expected_nonce=str(state_record.nonce),
                            client_id=client_id,
                            # Lets the verifier check the at_hash claim
                            access_token=(
                                str(issued_access_token)
                                if issued_access_token
                                else None
                            ),
                        )
            except AuthenticationError:
                raise
            except Exception as e:
                logger.error(f"OAuth token exchange failed: {e!s}")
                raise AuthenticationError(
                    "Failed to exchange authorization code"
                ) from e

            # Checked outside the try below so this specific error is not
            # replaced by the generic user-info failure message
            access_token = token.get("access_token")
            if not access_token:
                raise AuthenticationError("No access token received")

            # Get user info from provider
            try:
                user_info = await OAuthService._get_user_info(
                    client, provider, config, access_token
                )
            except AuthenticationError:
                raise
            except Exception as e:
                logger.error(f"Failed to get user info: {e!s}")
                raise AuthenticationError(
                    "Failed to get user information from provider"
                ) from e

        # Process user info and create/link account.
        # The id is validated BEFORE str(): str(None) is the truthy "None",
        # which would otherwise pass as a real provider user id.
        raw_provider_user_id = user_info.get("id") or user_info.get("sub")
        if not raw_provider_user_id:
            raise AuthenticationError("Provider did not return user ID")
        provider_user_id = str(raw_provider_user_id)

        # Email can be None if user didn't grant email permission
        # SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
        email_raw = user_info.get("email")
        provider_email: str | None = (
            str(email_raw).lower().strip() if email_raw else None
        )

        # Check if this OAuth account already exists
        existing_oauth = await oauth_account.get_by_provider_id(
            db, provider=provider, provider_user_id=provider_user_id
        )
        is_new_user = False

        if existing_oauth:
            # Existing OAuth account - login
            # NOTE(review): assumes get_by_provider_id loads ``.user`` eagerly;
            # a lazy load would not work on an async session - confirm.
            user = existing_oauth.user
            if not user.is_active:
                raise AuthenticationError("User account is inactive")

            # Refresh stored tokens. This path keeps a 1-hour fallback when
            # the provider omits expires_in.
            await oauth_account.update_tokens(
                db,
                account=existing_oauth,
                access_token_encrypted=token.get("access_token"),
                refresh_token_encrypted=token.get("refresh_token"),
                token_expires_at=datetime.now(UTC)
                + timedelta(seconds=token.get("expires_in", 3600)),
            )
            logger.info(f"OAuth login successful for {user.email} via {provider}")

        elif state_record.user_id:
            # Account linking flow (user is already logged in)
            result = await db.execute(
                select(User).where(User.id == state_record.user_id)
            )
            user = result.scalar_one_or_none()
            if not user:
                raise AuthenticationError("User not found for account linking")
            # Same check as the other branches: fresh JWTs are issued below
            if not user.is_active:
                raise AuthenticationError("User account is inactive")

            # Check if user already has this provider linked
            user_id = cast(UUID, user.id)
            existing_provider = await oauth_account.get_user_account_by_provider(
                db, user_id=user_id, provider=provider
            )
            if existing_provider:
                raise AuthenticationError(
                    f"You already have a {provider} account linked"
                )

            # Create OAuth account link
            await OAuthService._link_oauth_account(
                db,
                user_id=user_id,
                provider=provider,
                provider_user_id=provider_user_id,
                provider_email=provider_email,
                token=token,
            )
            logger.info(f"OAuth account linked: {provider} -> {user.email}")

        else:
            # New OAuth login - check for existing user by email
            # NOTE(review): auto-linking trusts the provider-reported email;
            # consider requiring a verified-email signal from the provider.
            user = None
            if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
                result = await db.execute(
                    select(User).where(User.email == provider_email)
                )
                user = result.scalar_one_or_none()

            if user:
                # Auto-link to existing user
                if not user.is_active:
                    raise AuthenticationError("User account is inactive")

                # Check if user already has this provider linked
                user_id = cast(UUID, user.id)
                existing_provider = await oauth_account.get_user_account_by_provider(
                    db, user_id=user_id, provider=provider
                )
                if existing_provider:
                    # This shouldn't happen if we got here, but safety check
                    logger.warning(
                        f"OAuth account already linked (race condition?): {provider} -> {user.email}"
                    )
                else:
                    # Create OAuth account link
                    await OAuthService._link_oauth_account(
                        db,
                        user_id=user_id,
                        provider=provider,
                        provider_user_id=provider_user_id,
                        provider_email=provider_email,
                        token=token,
                    )
                    logger.info(
                        f"OAuth auto-linked by email: {provider} -> {user.email}"
                    )
            else:
                # Create new user
                if not provider_email:
                    raise AuthenticationError(
                        f"Email is required for registration. "
                        f"Please grant email permission to {provider}."
                    )
                user = await OAuthService._create_oauth_user(
                    db,
                    email=provider_email,
                    provider=provider,
                    provider_user_id=provider_user_id,
                    user_info=user_info,
                    token=token,
                )
                is_new_user = True
                logger.info(f"New user created via OAuth: {user.email} ({provider})")

        # Generate JWT tokens
        claims = {
            "is_superuser": user.is_superuser,
            "email": user.email,
            "first_name": user.first_name,
        }
        access_token_jwt = create_access_token(subject=str(user.id), claims=claims)
        refresh_token_jwt = create_refresh_token(subject=str(user.id))
        return OAuthCallbackResponse(
            access_token=access_token_jwt,
            refresh_token=refresh_token_jwt,
            token_type="bearer",
            expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            is_new_user=is_new_user,
        )

    @staticmethod
    async def _get_user_info(
        client: AsyncOAuth2Client,
        provider: str,
        config: OAuthProviderConfig,
        access_token: str,
    ) -> dict[str, object]:
        """
        Get user info from OAuth provider.

        For GitHub, when the profile response has no email, the provider's
        email endpoint is queried and the primary verified address (if any)
        is stored under ``"email"``.

        Args:
            client: HTTP client used for the provider requests
            provider: OAuth provider (google, github)
            config: Provider configuration holding the endpoint URLs
            access_token: Provider access token sent as a Bearer token

        Returns:
            The provider's user profile as a dict
        """
        headers = {"Authorization": f"Bearer {access_token}"}
        if provider == "github":
            # GitHub returns JSON with Accept header
            headers["Accept"] = "application/vnd.github+json"

        resp = await client.get(config["userinfo_url"], headers=headers)
        resp.raise_for_status()
        user_info = resp.json()

        # GitHub requires separate request for email
        if provider == "github" and not user_info.get("email"):
            email_resp = await client.get(config["email_url"], headers=headers)
            email_resp.raise_for_status()
            # Find primary verified email
            primary = next(
                (
                    entry
                    for entry in email_resp.json()
                    if entry.get("primary") and entry.get("verified")
                ),
                None,
            )
            if primary is not None:
                user_info["email"] = primary["email"]
        return user_info

    @staticmethod
    async def _verify_google_id_token(
        id_token: str,
        expected_nonce: str,
        client_id: str,
        access_token: str | None = None,
    ) -> dict[str, object]:
        """
        Verify Google ID token signature and claims.

        SECURITY: This properly verifies the ID token by:
        1. Fetching Google's public keys (JWKS)
        2. Verifying the JWT signature against the public key
        3. Validating issuer, audience, expiry, and nonce claims
        4. Validating the at_hash claim when access_token is supplied

        Args:
            id_token: The ID token JWT string
            expected_nonce: The nonce we sent in the authorization request
            client_id: Our OAuth client ID (expected audience)
            access_token: Access token issued alongside the ID token. When
                omitted, at_hash verification is switched off, because the
                JWT library rejects a token carrying at_hash if it has no
                access token to compare against.

        Returns:
            Decoded ID token payload

        Raises:
            AuthenticationError: If verification fails
        """
        import httpx
        from jose import jwt as jose_jwt
        from jose.exceptions import JWTError

        try:
            # Get the key ID from the token header first: it is a local
            # parse, so a malformed token is rejected before any network call
            unverified_header = jose_jwt.get_unverified_header(id_token)
            kid = unverified_header.get("kid")
            if not kid:
                raise AuthenticationError("ID token missing key ID (kid)")

            # Fetch Google's public keys (JWKS)
            # In production, consider caching this with TTL matching Cache-Control header
            async with httpx.AsyncClient() as client:
                jwks_response = await client.get(
                    OAuthService.GOOGLE_JWKS_URL,
                    timeout=10.0,
                )
                jwks_response.raise_for_status()
                jwks = jwks_response.json()

            # Find the matching public key
            public_key = next(
                (key for key in jwks.get("keys", []) if key.get("kid") == kid),
                None,
            )
            if not public_key:
                raise AuthenticationError("ID token signed with unknown key")

            # Verify the token signature and decode claims
            # jose library will verify signature against the JWK
            payload = jose_jwt.decode(
                id_token,
                public_key,
                algorithms=["RS256"],  # Google uses RS256
                audience=client_id,
                issuer=OAuthService.GOOGLE_ISSUERS,
                access_token=access_token,
                options={
                    "verify_signature": True,
                    "verify_aud": True,
                    "verify_iss": True,
                    "verify_exp": True,
                    "verify_iat": True,
                    "verify_at_hash": access_token is not None,
                },
            )

            # Verify nonce (OIDC replay attack protection)
            token_nonce = payload.get("nonce")
            if token_nonce != expected_nonce:
                logger.warning(
                    f"OAuth ID token nonce mismatch: expected {expected_nonce}, "
                    f"got {token_nonce}"
                )
                raise AuthenticationError("Invalid ID token nonce")

            logger.debug("Google ID token verified successfully")
            return payload
        except AuthenticationError:
            # Raised deliberately above; must not be relabelled as "unexpected"
            raise
        except JWTError as e:
            logger.warning(f"Google ID token verification failed: {e}")
            raise AuthenticationError("Invalid ID token signature") from e
        except httpx.HTTPError as e:
            logger.error(f"Failed to fetch Google JWKS: {e}")
            # If we can't verify the ID token, fail closed for security
            raise AuthenticationError("Failed to verify ID token") from e
        except Exception as e:
            logger.error(f"Unexpected error verifying Google ID token: {e}")
            raise AuthenticationError("ID token verification error") from e

    @staticmethod
    async def _create_oauth_user(
        db: AsyncSession,
        *,
        email: str,
        provider: str,
        provider_user_id: str,
        user_info: dict,
        token: dict,
    ) -> User:
        """
        Create a new user from OAuth provider data.

        Args:
            db: Database session
            email: Email address for the new user
            provider: OAuth provider (google, github)
            provider_user_id: The user's ID at the provider
            user_info: Profile returned by the provider (source of the name)
            token: Token response from the provider

        Returns:
            The newly created, refreshed User
        """
        # Extract name from user_info
        first_name = "User"
        last_name = None
        if provider == "google":
            # "or" chain so a key that is present but null still falls back
            first_name = (
                user_info.get("given_name") or user_info.get("name") or "User"
            )
            last_name = user_info.get("family_name")
        elif provider == "github":
            # GitHub has full name, try to split
            full_name = (
                str(user_info.get("name") or user_info.get("login") or "User").strip()
                or "User"
            )
            first_name, _, remainder = full_name.partition(" ")
            last_name = remainder.strip() or None

        # Create user (no password for OAuth-only users)
        user = User(
            email=email,
            password_hash=None,  # OAuth-only user
            first_name=first_name,
            last_name=last_name,
            is_active=True,
            is_superuser=False,
        )
        db.add(user)
        try:
            await db.flush()  # Get user.id

            # Create OAuth account link
            await OAuthService._link_oauth_account(
                db,
                user_id=cast(UUID, user.id),
                provider=provider,
                provider_user_id=provider_user_id,
                provider_email=email,
                token=token,
            )
            await db.commit()
        except Exception:
            # Roll back before re-raising so the session stays usable
            await db.rollback()
            raise
        await db.refresh(user)
        return user

    @staticmethod
    async def unlink_provider(
        db: AsyncSession,
        *,
        user: User,
        provider: str,
    ) -> bool:
        """
        Unlink an OAuth provider from a user account.

        Args:
            db: Database session
            user: User to unlink from
            provider: Provider to unlink

        Returns:
            True if unlinked successfully

        Raises:
            AuthenticationError: If unlinking would leave user without login
                method, or no account for the provider is linked
        """
        # Check if user can safely remove this OAuth account
        # Note: We query directly instead of using user.can_remove_oauth property
        # because the property uses lazy loading which doesn't work in async context
        user_id = cast(UUID, user.id)
        has_password = user.password_hash is not None
        oauth_accounts = await oauth_account.get_user_accounts(db, user_id=user_id)
        can_remove = has_password or len(oauth_accounts) > 1
        if not can_remove:
            raise AuthenticationError(
                "Cannot unlink OAuth account. You must have either a password set "
                "or at least one other OAuth provider linked."
            )

        deleted = await oauth_account.delete_account(
            db, user_id=user_id, provider=provider
        )
        if not deleted:
            raise AuthenticationError(f"No {provider} account found to unlink")

        logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
        return True

    @staticmethod
    async def cleanup_expired_states(db: AsyncSession) -> int:
        """
        Clean up expired OAuth states.

        Should be called periodically (e.g., by a background task).

        Args:
            db: Database session

        Returns:
            Number of states cleaned up
        """
        return await oauth_state.cleanup_expired(db)