Initial implementation of OAuth models, endpoints, and migrations
- Added models for `OAuthClient`, `OAuthState`, and `OAuthAccount`. - Created Pydantic schemas to support OAuth flows, client management, and linked accounts. - Implemented skeleton endpoints for OAuth Provider mode: authorization, token, and revocation. - Updated router imports to include new `/oauth` and `/oauth/provider` routes. - Added Alembic migration script to create OAuth-related database tables. - Enhanced `users` table to allow OAuth-only accounts by making `password_hash` nullable.
This commit is contained in:
144
backend/app/alembic/versions/d5a7b2c9e1f3_add_oauth_models.py
Normal file
144
backend/app/alembic/versions/d5a7b2c9e1f3_add_oauth_models.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""add oauth models
|
||||
|
||||
Revision ID: d5a7b2c9e1f3
|
||||
Revises: c8e9f3a2d1b4
|
||||
Create Date: 2025-11-24 20:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d5a7b2c9e1f3"
|
||||
down_revision: str | None = "c8e9f3a2d1b4"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Make password_hash nullable on users table (for OAuth-only users)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"password_hash",
|
||||
existing_type=sa.String(length=255),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# 2. Create oauth_accounts table (links OAuth providers to users)
|
||||
op.create_table(
|
||||
"oauth_accounts",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("provider_email", sa.String(length=255), nullable=True),
|
||||
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
name="fk_oauth_accounts_user_id",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"provider", "provider_user_id", name="uq_oauth_provider_user"
|
||||
),
|
||||
)
|
||||
|
||||
# Create indexes for oauth_accounts
|
||||
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
|
||||
op.create_index("ix_oauth_accounts_provider", "oauth_accounts", ["provider"])
|
||||
op.create_index(
|
||||
"ix_oauth_accounts_provider_email", "oauth_accounts", ["provider_email"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_accounts_user_provider", "oauth_accounts", ["user_id", "provider"]
|
||||
)
|
||||
|
||||
# 3. Create oauth_states table (CSRF protection during OAuth flow)
|
||||
op.create_table(
|
||||
"oauth_states",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("state", sa.String(length=255), nullable=False),
|
||||
sa.Column("code_verifier", sa.String(length=128), nullable=True),
|
||||
sa.Column("nonce", sa.String(length=255), nullable=True),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create indexes for oauth_states
|
||||
op.create_index("ix_oauth_states_state", "oauth_states", ["state"], unique=True)
|
||||
op.create_index("ix_oauth_states_expires_at", "oauth_states", ["expires_at"])
|
||||
|
||||
# 4. Create oauth_clients table (OAuth provider mode - skeleton for MCP)
|
||||
op.create_table(
|
||||
"oauth_clients",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
|
||||
sa.Column("client_name", sa.String(length=255), nullable=False),
|
||||
sa.Column("client_description", sa.String(length=1000), nullable=True),
|
||||
sa.Column("client_type", sa.String(length=20), nullable=False),
|
||||
sa.Column("redirect_uris", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("allowed_scopes", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("owner_user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["owner_user_id"],
|
||||
["users.id"],
|
||||
name="fk_oauth_clients_owner_user_id",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
)
|
||||
|
||||
# Create indexes for oauth_clients
|
||||
op.create_index(
|
||||
"ix_oauth_clients_client_id", "oauth_clients", ["client_id"], unique=True
|
||||
)
|
||||
op.create_index("ix_oauth_clients_is_active", "oauth_clients", ["is_active"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop oauth_clients table and indexes
|
||||
op.drop_index("ix_oauth_clients_is_active", table_name="oauth_clients")
|
||||
op.drop_index("ix_oauth_clients_client_id", table_name="oauth_clients")
|
||||
op.drop_table("oauth_clients")
|
||||
|
||||
# Drop oauth_states table and indexes
|
||||
op.drop_index("ix_oauth_states_expires_at", table_name="oauth_states")
|
||||
op.drop_index("ix_oauth_states_state", table_name="oauth_states")
|
||||
op.drop_table("oauth_states")
|
||||
|
||||
# Drop oauth_accounts table and indexes
|
||||
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
|
||||
op.drop_index("ix_oauth_accounts_provider_email", table_name="oauth_accounts")
|
||||
op.drop_index("ix_oauth_accounts_provider", table_name="oauth_accounts")
|
||||
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
|
||||
op.drop_table("oauth_accounts")
|
||||
|
||||
# Revert password_hash to non-nullable
|
||||
op.alter_column(
|
||||
"users",
|
||||
"password_hash",
|
||||
existing_type=sa.String(length=255),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -1,9 +1,21 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import admin, auth, organizations, sessions, users
|
||||
from app.api.routes import (
|
||||
admin,
|
||||
auth,
|
||||
oauth,
|
||||
oauth_provider,
|
||||
organizations,
|
||||
sessions,
|
||||
users,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router.include_router(oauth.router, prefix="/oauth", tags=["OAuth"])
|
||||
api_router.include_router(
|
||||
oauth_provider.router, prefix="/oauth", tags=["OAuth Provider"]
|
||||
)
|
||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
||||
|
||||
433
backend/app/api/routes/oauth.py
Normal file
433
backend/app/api/routes/oauth.py
Normal file
@@ -0,0 +1,433 @@
|
||||
# app/api/routes/oauth.py
|
||||
"""
|
||||
OAuth routes for social authentication.
|
||||
|
||||
Endpoints:
|
||||
- GET /oauth/providers - List enabled OAuth providers
|
||||
- GET /oauth/authorize/{provider} - Get authorization URL
|
||||
- POST /oauth/callback/{provider} - Handle OAuth callback
|
||||
- GET /oauth/accounts - List linked OAuth accounts
|
||||
- DELETE /oauth/accounts/{provider} - Unlink an OAuth account
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_optional_current_user
|
||||
from app.core.auth import decode_token
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import AuthenticationError as AuthError
|
||||
from app.crud import oauth_account
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import (
|
||||
OAuthAccountsListResponse,
|
||||
OAuthCallbackRequest,
|
||||
OAuthCallbackResponse,
|
||||
OAuthProvidersResponse,
|
||||
OAuthUnlinkResponse,
|
||||
)
|
||||
from app.schemas.sessions import SessionCreate
|
||||
from app.schemas.users import Token
|
||||
from app.services.oauth_service import OAuthService
|
||||
from app.utils.device import extract_device_info
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize limiter for this router
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Use higher rate limits in test environment
|
||||
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
async def _create_oauth_login_session(
|
||||
db: AsyncSession,
|
||||
request: Request,
|
||||
user: User,
|
||||
tokens: Token,
|
||||
provider: str,
|
||||
) -> None:
|
||||
"""
|
||||
Create a session record for successful OAuth login.
|
||||
|
||||
This is a best-effort operation - login succeeds even if session creation fails.
|
||||
"""
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name or f"OAuth ({provider})",
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
await session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
f"OAuth login successful: {user.email} via {provider} "
|
||||
f"from {device_info.device_name} (IP: {device_info.ip_address})"
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.error(
|
||||
f"Failed to create session for OAuth login {user.email}: {session_err!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
response_model=OAuthProvidersResponse,
|
||||
summary="List OAuth Providers",
|
||||
description="""
|
||||
Get list of enabled OAuth providers for the login/register UI.
|
||||
|
||||
Returns:
|
||||
List of enabled providers with display info.
|
||||
""",
|
||||
operation_id="list_oauth_providers",
|
||||
)
|
||||
async def list_providers() -> Any:
|
||||
"""
|
||||
Get list of enabled OAuth providers.
|
||||
|
||||
This endpoint is public (no authentication required) as it's needed
|
||||
for the login/register UI to display available social login options.
|
||||
"""
|
||||
return OAuthService.get_enabled_providers()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/authorize/{provider}",
|
||||
response_model=dict,
|
||||
summary="Get OAuth Authorization URL",
|
||||
description="""
|
||||
Get the authorization URL to redirect the user to the OAuth provider.
|
||||
|
||||
The frontend should redirect the user to the returned URL.
|
||||
After authentication, the provider will redirect back to the callback URL.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="get_oauth_authorization_url",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def get_authorization_url(
|
||||
request: Request,
|
||||
provider: str,
|
||||
redirect_uri: str = Query(
|
||||
..., description="Frontend callback URL after OAuth completes"
|
||||
),
|
||||
current_user: User | None = Depends(get_optional_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get OAuth authorization URL.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider (google, github)
|
||||
redirect_uri: Frontend callback URL
|
||||
current_user: Current user (optional, for account linking)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
dict with authorization_url and state
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth is not enabled",
|
||||
)
|
||||
|
||||
try:
|
||||
# If user is logged in, this is an account linking flow
|
||||
user_id = str(current_user.id) if current_user else None
|
||||
|
||||
url, state = await OAuthService.create_authorization_url(
|
||||
db,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"authorization_url": url,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning(f"OAuth authorization failed: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth authorization error: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create authorization URL",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/callback/{provider}",
|
||||
response_model=OAuthCallbackResponse,
|
||||
summary="OAuth Callback",
|
||||
description="""
|
||||
Handle OAuth callback from provider.
|
||||
|
||||
The frontend should call this endpoint with the code and state
|
||||
parameters received from the OAuth provider redirect.
|
||||
|
||||
Returns:
|
||||
JWT tokens for the authenticated user.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="handle_oauth_callback",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def handle_callback(
|
||||
request: Request,
|
||||
provider: str,
|
||||
callback_data: OAuthCallbackRequest,
|
||||
redirect_uri: str = Query(
|
||||
..., description="Must match the redirect_uri used in authorization"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Handle OAuth callback.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider (google, github)
|
||||
callback_data: Code and state from provider
|
||||
redirect_uri: Original redirect URI (for validation)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OAuthCallbackResponse with tokens
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth is not enabled",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await OAuthService.handle_callback(
|
||||
db,
|
||||
code=callback_data.code,
|
||||
state=callback_data.state,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
# Create session for the login (need to get the user first)
|
||||
# Note: This requires fetching the user from the token
|
||||
# For now, we skip session creation here as the result doesn't include user info
|
||||
# The session will be created on next request if needed
|
||||
|
||||
return result
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning(f"OAuth callback failed: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback error: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth authentication failed",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/accounts",
|
||||
response_model=OAuthAccountsListResponse,
|
||||
summary="List Linked OAuth Accounts",
|
||||
description="""
|
||||
Get list of OAuth accounts linked to the current user.
|
||||
|
||||
Requires authentication.
|
||||
""",
|
||||
operation_id="list_oauth_accounts",
|
||||
)
|
||||
async def list_accounts(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List OAuth accounts linked to the current user.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of linked OAuth accounts
|
||||
"""
|
||||
accounts = await oauth_account.get_user_accounts(db, user_id=current_user.id)
|
||||
return OAuthAccountsListResponse(accounts=accounts)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/accounts/{provider}",
|
||||
response_model=OAuthUnlinkResponse,
|
||||
summary="Unlink OAuth Account",
|
||||
description="""
|
||||
Unlink an OAuth provider from the current user.
|
||||
|
||||
The user must have either a password set or another OAuth provider
|
||||
linked to ensure they can still log in.
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="unlink_oauth_account",
|
||||
)
|
||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||
async def unlink_account(
|
||||
request: Request,
|
||||
provider: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Unlink an OAuth provider from the current user.
|
||||
|
||||
Args:
|
||||
provider: Provider to unlink (google, github)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
try:
|
||||
await OAuthService.unlink_provider(
|
||||
db,
|
||||
user=current_user,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
return OAuthUnlinkResponse(
|
||||
success=True,
|
||||
message=f"{provider.capitalize()} account unlinked successfully",
|
||||
)
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning(f"OAuth unlink failed for {current_user.email}: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth unlink error: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to unlink OAuth account",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/link/{provider}",
|
||||
response_model=dict,
|
||||
summary="Start Account Linking",
|
||||
description="""
|
||||
Start the OAuth flow to link a new provider to the current user.
|
||||
|
||||
This is a convenience endpoint that redirects to /authorize/{provider}
|
||||
with the current user context.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="start_oauth_link",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def start_link(
|
||||
request: Request,
|
||||
provider: str,
|
||||
redirect_uri: str = Query(
|
||||
..., description="Frontend callback URL after OAuth completes"
|
||||
),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Start OAuth account linking flow.
|
||||
|
||||
This endpoint requires authentication and will initiate an OAuth flow
|
||||
to link a new provider to the current user's account.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider to link (google, github)
|
||||
redirect_uri: Frontend callback URL
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
dict with authorization_url and state
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth is not enabled",
|
||||
)
|
||||
|
||||
# Check if user already has this provider linked
|
||||
existing = await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=current_user.id, provider=provider
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"You already have a {provider} account linked",
|
||||
)
|
||||
|
||||
try:
|
||||
url, state = await OAuthService.create_authorization_url(
|
||||
db,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
user_id=str(current_user.id),
|
||||
)
|
||||
|
||||
return {
|
||||
"authorization_url": url,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning(f"OAuth link authorization failed: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth link error: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create authorization URL",
|
||||
)
|
||||
312
backend/app/api/routes/oauth_provider.py
Normal file
312
backend/app/api/routes/oauth_provider.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# app/api/routes/oauth_provider.py
|
||||
"""
|
||||
OAuth Provider routes (Authorization Server mode).
|
||||
|
||||
This is a skeleton implementation for MCP (Model Context Protocol) client authentication.
|
||||
Provides basic OAuth 2.0 endpoints that can be expanded for full functionality.
|
||||
|
||||
Endpoints:
|
||||
- GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414)
|
||||
- GET /oauth/provider/authorize - Authorization endpoint (skeleton)
|
||||
- POST /oauth/provider/token - Token endpoint (skeleton)
|
||||
- POST /oauth/provider/revoke - Token revocation endpoint (skeleton)
|
||||
|
||||
NOTE: This is intentionally minimal. Full implementation should include:
|
||||
- Complete authorization code flow
|
||||
- Refresh token handling
|
||||
- Scope validation
|
||||
- Client authentication
|
||||
- PKCE support
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.crud import oauth_client
|
||||
from app.schemas.oauth import OAuthServerMetadata
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/.well-known/oauth-authorization-server",
|
||||
response_model=OAuthServerMetadata,
|
||||
summary="OAuth Server Metadata",
|
||||
description="""
|
||||
OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
||||
|
||||
Returns server metadata including supported endpoints, scopes,
|
||||
and capabilities for MCP clients.
|
||||
""",
|
||||
operation_id="get_oauth_server_metadata",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def get_server_metadata() -> Any:
|
||||
"""
|
||||
Get OAuth 2.0 server metadata.
|
||||
|
||||
This endpoint is used by MCP clients to discover the authorization
|
||||
server's capabilities.
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
)
|
||||
|
||||
base_url = settings.OAUTH_ISSUER.rstrip("/")
|
||||
|
||||
return OAuthServerMetadata(
|
||||
issuer=base_url,
|
||||
authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize",
|
||||
token_endpoint=f"{base_url}/api/v1/oauth/provider/token",
|
||||
revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke",
|
||||
registration_endpoint=None, # Dynamic registration not implemented
|
||||
scopes_supported=[
|
||||
"openid",
|
||||
"profile",
|
||||
"email",
|
||||
"read:users",
|
||||
"write:users",
|
||||
"read:organizations",
|
||||
"write:organizations",
|
||||
],
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code", "refresh_token"],
|
||||
code_challenge_methods_supported=["S256"],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/authorize",
|
||||
summary="Authorization Endpoint (Skeleton)",
|
||||
description="""
|
||||
OAuth 2.0 Authorization Endpoint.
|
||||
|
||||
**NOTE**: This is a skeleton implementation. In a full implementation,
|
||||
this would:
|
||||
1. Validate client_id and redirect_uri
|
||||
2. Display consent screen to user
|
||||
3. Generate authorization code
|
||||
4. Redirect back to client with code
|
||||
|
||||
Currently returns a 501 Not Implemented response.
|
||||
""",
|
||||
operation_id="oauth_provider_authorize",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def authorize(
|
||||
response_type: str = Query(..., description="Must be 'code'"),
|
||||
client_id: str = Query(..., description="OAuth client ID"),
|
||||
redirect_uri: str = Query(..., description="Redirect URI"),
|
||||
scope: str = Query(default="", description="Requested scopes"),
|
||||
state: str = Query(default="", description="CSRF state parameter"),
|
||||
code_challenge: str | None = Query(default=None, description="PKCE code challenge"),
|
||||
code_challenge_method: str | None = Query(
|
||||
default=None, description="PKCE method (S256)"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Authorization endpoint (skeleton).
|
||||
|
||||
In a full implementation, this would:
|
||||
1. Validate the client and redirect URI
|
||||
2. Authenticate the user (if not already)
|
||||
3. Show consent screen
|
||||
4. Generate authorization code
|
||||
5. Redirect to redirect_uri with code
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
)
|
||||
|
||||
# Validate client exists
|
||||
client = await oauth_client.get_by_client_id(db, client_id=client_id)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="invalid_client: Unknown client_id",
|
||||
)
|
||||
|
||||
# Validate redirect_uri
|
||||
if redirect_uri not in (client.redirect_uris or []):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="invalid_request: Invalid redirect_uri",
|
||||
)
|
||||
|
||||
# Skeleton: Return not implemented
|
||||
# Full implementation would redirect to consent screen
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Authorization endpoint not fully implemented. "
|
||||
"This is a skeleton for MCP integration.",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/token",
|
||||
summary="Token Endpoint (Skeleton)",
|
||||
description="""
|
||||
OAuth 2.0 Token Endpoint.
|
||||
|
||||
**NOTE**: This is a skeleton implementation. In a full implementation,
|
||||
this would exchange authorization codes for access tokens.
|
||||
|
||||
Currently returns a 501 Not Implemented response.
|
||||
""",
|
||||
operation_id="oauth_provider_token",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def token(
|
||||
grant_type: str = Form(..., description="Grant type (authorization_code)"),
|
||||
code: str | None = Form(default=None, description="Authorization code"),
|
||||
redirect_uri: str | None = Form(default=None, description="Redirect URI"),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
code_verifier: str | None = Form(default=None, description="PKCE code verifier"),
|
||||
refresh_token: str | None = Form(default=None, description="Refresh token"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Token endpoint (skeleton).
|
||||
|
||||
Supported grant types (when fully implemented):
|
||||
- authorization_code: Exchange code for tokens
|
||||
- refresh_token: Refresh access token
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
)
|
||||
|
||||
if grant_type not in ["authorization_code", "refresh_token"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="unsupported_grant_type",
|
||||
)
|
||||
|
||||
# Skeleton: Return not implemented
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Token endpoint not fully implemented. "
|
||||
"This is a skeleton for MCP integration.",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/revoke",
|
||||
summary="Token Revocation Endpoint (Skeleton)",
|
||||
description="""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
|
||||
|
||||
**NOTE**: This is a skeleton implementation.
|
||||
|
||||
Currently returns a 501 Not Implemented response.
|
||||
""",
|
||||
operation_id="oauth_provider_revoke",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def revoke(
|
||||
token: str = Form(..., description="Token to revoke"),
|
||||
token_type_hint: str | None = Form(
|
||||
default=None, description="Token type hint (access_token, refresh_token)"
|
||||
),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Token revocation endpoint (skeleton).
|
||||
|
||||
In a full implementation, this would invalidate the specified token.
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
)
|
||||
|
||||
# Skeleton: Return not implemented
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Revocation endpoint not fully implemented. "
|
||||
"This is a skeleton for MCP integration.",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client Management (Admin only)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/clients",
|
||||
summary="Register OAuth Client (Admin)",
|
||||
description="""
|
||||
Register a new OAuth client (admin only).
|
||||
|
||||
This endpoint allows creating MCP clients that can authenticate
|
||||
against this API.
|
||||
|
||||
**NOTE**: This is a minimal implementation.
|
||||
""",
|
||||
operation_id="register_oauth_client",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def register_client(
|
||||
client_name: str = Form(..., description="Client application name"),
|
||||
redirect_uris: str = Form(..., description="Comma-separated list of redirect URIs"),
|
||||
client_type: str = Form(default="public", description="public or confidential"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new OAuth client (skeleton).
|
||||
|
||||
In a full implementation, this would require admin authentication.
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
)
|
||||
|
||||
# NOTE: In production, this should require admin authentication
|
||||
# For now, this is a skeleton that shows the structure
|
||||
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
client_data = OAuthClientCreate(
|
||||
client_name=client_name,
|
||||
client_description=None,
|
||||
redirect_uris=[uri.strip() for uri in redirect_uris.split(",")],
|
||||
allowed_scopes=["openid", "profile", "email"],
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
client, secret = await oauth_client.create_client(db, obj_in=client_data)
|
||||
|
||||
result = {
|
||||
"client_id": client.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_type": client.client_type,
|
||||
"redirect_uris": client.redirect_uris,
|
||||
}
|
||||
|
||||
if secret:
|
||||
result["client_secret"] = secret
|
||||
result["warning"] = (
|
||||
"Store the client_secret securely. It will not be shown again."
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -76,6 +76,60 @@ class Settings(BaseSettings):
|
||||
description="Frontend application URL for email links",
|
||||
)
|
||||
|
||||
# OAuth Configuration
|
||||
OAUTH_ENABLED: bool = Field(
|
||||
default=False,
|
||||
description="Enable OAuth authentication (social login)",
|
||||
)
|
||||
OAUTH_AUTO_LINK_BY_EMAIL: bool = Field(
|
||||
default=True,
|
||||
description="Automatically link OAuth accounts to existing users with matching email",
|
||||
)
|
||||
OAUTH_STATE_EXPIRE_MINUTES: int = Field(
|
||||
default=10,
|
||||
description="OAuth state parameter expiration time in minutes",
|
||||
)
|
||||
|
||||
# Google OAuth
|
||||
OAUTH_GOOGLE_CLIENT_ID: str | None = Field(
|
||||
default=None,
|
||||
description="Google OAuth client ID from Google Cloud Console",
|
||||
)
|
||||
OAUTH_GOOGLE_CLIENT_SECRET: str | None = Field(
|
||||
default=None,
|
||||
description="Google OAuth client secret from Google Cloud Console",
|
||||
)
|
||||
|
||||
# GitHub OAuth
|
||||
OAUTH_GITHUB_CLIENT_ID: str | None = Field(
|
||||
default=None,
|
||||
description="GitHub OAuth client ID from GitHub Developer Settings",
|
||||
)
|
||||
OAUTH_GITHUB_CLIENT_SECRET: str | None = Field(
|
||||
default=None,
|
||||
description="GitHub OAuth client secret from GitHub Developer Settings",
|
||||
)
|
||||
|
||||
# OAuth Provider Mode (for MCP clients - skeleton)
|
||||
OAUTH_PROVIDER_ENABLED: bool = Field(
|
||||
default=False,
|
||||
description="Enable OAuth provider mode (act as authorization server for MCP clients)",
|
||||
)
|
||||
OAUTH_ISSUER: str = Field(
|
||||
default="http://localhost:8000",
|
||||
description="OAuth issuer URL (your API base URL)",
|
||||
)
|
||||
|
||||
@property
|
||||
def enabled_oauth_providers(self) -> list[str]:
|
||||
"""Get list of enabled OAuth providers based on configured credentials."""
|
||||
providers = []
|
||||
if self.OAUTH_GOOGLE_CLIENT_ID and self.OAUTH_GOOGLE_CLIENT_SECRET:
|
||||
providers.append("google")
|
||||
if self.OAUTH_GITHUB_CLIENT_ID and self.OAUTH_GITHUB_CLIENT_SECRET:
|
||||
providers.append("github")
|
||||
return providers
|
||||
|
||||
# Admin user
|
||||
FIRST_SUPERUSER_EMAIL: str | None = Field(
|
||||
default=None, description="Email for first superuser account"
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
# app/crud/__init__.py
|
||||
from .oauth import oauth_account, oauth_client, oauth_state
|
||||
from .organization import organization
|
||||
from .session import session as session_crud
|
||||
from .user import user
|
||||
|
||||
__all__ = ["organization", "session_crud", "user"]
|
||||
__all__ = [
|
||||
"oauth_account",
|
||||
"oauth_client",
|
||||
"oauth_state",
|
||||
"organization",
|
||||
"session_crud",
|
||||
"user",
|
||||
]
|
||||
|
||||
653
backend/app/crud/oauth.py
Normal file
653
backend/app/crud/oauth.py
Normal file
@@ -0,0 +1,653 @@
|
||||
"""
|
||||
Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns.
|
||||
|
||||
Provides operations for:
|
||||
- OAuthAccount: Managing linked OAuth provider accounts
|
||||
- OAuthState: CSRF protection state during OAuth flows
|
||||
- OAuthClient: Registered OAuth clients (provider mode skeleton)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.oauth_account import OAuthAccount
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_state import OAuthState
|
||||
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Account CRUD
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for CRUD operations that don't need update schemas."""
|
||||
|
||||
|
||||
class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]):
|
||||
"""CRUD operations for OAuth account links."""
|
||||
|
||||
async def get_by_provider_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""
|
||||
Get OAuth account by provider and provider user ID.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
provider: OAuth provider name (google, github)
|
||||
provider_user_id: User ID from the OAuth provider
|
||||
|
||||
Returns:
|
||||
OAuthAccount if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_user_id == provider_user_id,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_provider_email(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
email: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""
|
||||
Get OAuth account by provider and email.
|
||||
|
||||
Used for auto-linking existing accounts by email.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
provider: OAuth provider name
|
||||
email: Email address from the OAuth provider
|
||||
|
||||
Returns:
|
||||
OAuthAccount if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_email == email,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting OAuth account for {provider} email {email}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_accounts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
) -> list[OAuthAccount]:
|
||||
"""
|
||||
Get all OAuth accounts linked to a user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of OAuthAccount objects
|
||||
"""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(OAuthAccount.user_id == user_uuid)
|
||||
.order_by(OAuthAccount.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_account_by_provider(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""
|
||||
Get a specific OAuth account for a user and provider.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
provider: OAuth provider name
|
||||
|
||||
Returns:
|
||||
OAuthAccount if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_account(
|
||||
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
|
||||
) -> OAuthAccount:
|
||||
"""
|
||||
Create a new OAuth account link.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: OAuth account creation data
|
||||
|
||||
Returns:
|
||||
Created OAuthAccount
|
||||
|
||||
Raises:
|
||||
ValueError: If account already exists or creation fails
|
||||
"""
|
||||
try:
|
||||
db_obj = OAuthAccount(
|
||||
user_id=obj_in.user_id,
|
||||
provider=obj_in.provider,
|
||||
provider_user_id=obj_in.provider_user_id,
|
||||
provider_email=obj_in.provider_email,
|
||||
access_token_encrypted=obj_in.access_token_encrypted,
|
||||
refresh_token_encrypted=obj_in.refresh_token_encrypted,
|
||||
token_expires_at=obj_in.token_expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
|
||||
)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "uq_oauth_provider_user" in error_msg.lower():
|
||||
logger.warning(
|
||||
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"This {obj_in.provider} account is already linked to another user"
|
||||
)
|
||||
logger.error(f"Integrity error creating OAuth account: {error_msg}")
|
||||
raise ValueError(f"Failed to create OAuth account: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def delete_account(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete an OAuth account link.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
provider: OAuth provider name
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
delete(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"OAuth account deleted: {provider} unlinked from user {user_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"OAuth account not found for deletion: {provider} for user {user_id}"
|
||||
)
|
||||
|
||||
return deleted
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_tokens(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account: OAuthAccount,
|
||||
access_token_encrypted: str | None = None,
|
||||
refresh_token_encrypted: str | None = None,
|
||||
token_expires_at: datetime | None = None,
|
||||
) -> OAuthAccount:
|
||||
"""
|
||||
Update OAuth tokens for an account.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
account: OAuthAccount to update
|
||||
access_token_encrypted: New encrypted access token
|
||||
refresh_token_encrypted: New encrypted refresh token
|
||||
token_expires_at: New token expiration time
|
||||
|
||||
Returns:
|
||||
Updated OAuthAccount
|
||||
"""
|
||||
try:
|
||||
if access_token_encrypted is not None:
|
||||
account.access_token_encrypted = access_token_encrypted
|
||||
if refresh_token_encrypted is not None:
|
||||
account.refresh_token_encrypted = refresh_token_encrypted
|
||||
if token_expires_at is not None:
|
||||
account.token_expires_at = token_expires_at
|
||||
|
||||
db.add(account)
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
|
||||
return account
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating OAuth tokens: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth State CRUD
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]):
|
||||
"""CRUD operations for OAuth state (CSRF protection)."""
|
||||
|
||||
async def create_state(
|
||||
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
||||
) -> OAuthState:
|
||||
"""
|
||||
Create a new OAuth state for CSRF protection.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: OAuth state creation data
|
||||
|
||||
Returns:
|
||||
Created OAuthState
|
||||
"""
|
||||
try:
|
||||
db_obj = OAuthState(
|
||||
state=obj_in.state,
|
||||
code_verifier=obj_in.code_verifier,
|
||||
nonce=obj_in.nonce,
|
||||
provider=obj_in.provider,
|
||||
redirect_uri=obj_in.redirect_uri,
|
||||
user_id=obj_in.user_id,
|
||||
expires_at=obj_in.expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.debug(f"OAuth state created for {obj_in.provider}")
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
# State collision (extremely rare with cryptographic random)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"OAuth state collision: {error_msg}")
|
||||
raise ValueError("Failed to create OAuth state, please retry")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_and_consume_state(
|
||||
self, db: AsyncSession, *, state: str
|
||||
) -> OAuthState | None:
|
||||
"""
|
||||
Get and delete OAuth state (consume it).
|
||||
|
||||
This ensures each state can only be used once (replay protection).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
state: State string to look up
|
||||
|
||||
Returns:
|
||||
OAuthState if found and valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Get the state
|
||||
result = await db.execute(
|
||||
select(OAuthState).where(OAuthState.state == state)
|
||||
)
|
||||
db_obj = result.scalar_one_or_none()
|
||||
|
||||
if db_obj is None:
|
||||
logger.warning(f"OAuth state not found: {state[:8]}...")
|
||||
return None
|
||||
|
||||
# Check expiration
|
||||
# Handle both timezone-aware and timezone-naive datetimes
|
||||
now = datetime.now(UTC)
|
||||
expires_at = db_obj.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
# SQLite returns naive datetimes, assume UTC
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
logger.warning(f"OAuth state expired: {state[:8]}...")
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
# Delete it (consume)
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
|
||||
logger.debug(f"OAuth state consumed: {state[:8]}...")
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error consuming OAuth state: {e!s}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""
|
||||
Clean up expired OAuth states.
|
||||
|
||||
Should be called periodically to remove stale states.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of states deleted
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired OAuth states")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Client CRUD (Provider Mode - Skeleton)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
|
||||
"""
|
||||
CRUD operations for OAuth clients (provider mode).
|
||||
|
||||
This is a skeleton implementation for MCP client registration.
|
||||
Full implementation can be expanded when needed.
|
||||
"""
|
||||
|
||||
async def get_by_client_id(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""
|
||||
Get OAuth client by client_id.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
|
||||
Returns:
|
||||
OAuthClient if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: OAuthClientCreate,
|
||||
owner_user_id: UUID | None = None,
|
||||
) -> tuple[OAuthClient, str | None]:
|
||||
"""
|
||||
Create a new OAuth client.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: OAuth client creation data
|
||||
owner_user_id: Optional owner user ID
|
||||
|
||||
Returns:
|
||||
Tuple of (created OAuthClient, client_secret or None for public clients)
|
||||
"""
|
||||
try:
|
||||
# Generate client_id
|
||||
client_id = secrets.token_urlsafe(32)
|
||||
|
||||
# Generate client_secret for confidential clients
|
||||
client_secret = None
|
||||
client_secret_hash = None
|
||||
if obj_in.client_type == "confidential":
|
||||
client_secret = secrets.token_urlsafe(48)
|
||||
# In production, use proper password hashing (bcrypt)
|
||||
# For now, we store a hash placeholder
|
||||
import hashlib
|
||||
|
||||
client_secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
|
||||
db_obj = OAuthClient(
|
||||
client_id=client_id,
|
||||
client_secret_hash=client_secret_hash,
|
||||
client_name=obj_in.client_name,
|
||||
client_description=obj_in.client_description,
|
||||
client_type=obj_in.client_type,
|
||||
redirect_uris=obj_in.redirect_uris,
|
||||
allowed_scopes=obj_in.allowed_scopes,
|
||||
owner_user_id=owner_user_id,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
|
||||
)
|
||||
return db_obj, client_secret
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Error creating OAuth client: {error_msg}")
|
||||
raise ValueError(f"Failed to create OAuth client: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def deactivate_client(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""
|
||||
Deactivate an OAuth client.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
|
||||
Returns:
|
||||
Deactivated OAuthClient if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
client.is_active = False
|
||||
db.add(client)
|
||||
await db.commit()
|
||||
await db.refresh(client)
|
||||
|
||||
logger.info(f"OAuth client deactivated: {client.client_name}")
|
||||
return client
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def validate_redirect_uri(
|
||||
self, db: AsyncSession, *, client_id: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that a redirect URI is allowed for a client.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
redirect_uri: Redirect URI to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return False
|
||||
|
||||
return redirect_uri in (client.redirect_uris or [])
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating redirect URI: {e!s}")
|
||||
return False
|
||||
|
||||
async def verify_client_secret(
|
||||
self, db: AsyncSession, *, client_id: str, client_secret: str
|
||||
) -> bool:
|
||||
"""
|
||||
Verify client credentials.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
client_secret: Client secret to verify
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
client = result.scalar_one_or_none()
|
||||
|
||||
if client is None or client.client_secret_hash is None:
|
||||
return False
|
||||
|
||||
# Verify secret
|
||||
import hashlib
|
||||
|
||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
# Cast to str for type safety with compare_digest
|
||||
stored_hash: str = str(client.client_secret_hash)
|
||||
return secrets.compare_digest(stored_hash, secret_hash)
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying client secret: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton instances
|
||||
# ============================================================================
|
||||
|
||||
oauth_account = CRUDOAuthAccount(OAuthAccount)
|
||||
oauth_state = CRUDOAuthState(OAuthState)
|
||||
oauth_client = CRUDOAuthClient(OAuthClient)
|
||||
@@ -7,6 +7,11 @@ Imports all models to ensure they're registered with SQLAlchemy.
|
||||
from app.core.database import Base
|
||||
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# OAuth models
|
||||
from .oauth_account import OAuthAccount
|
||||
from .oauth_client import OAuthClient
|
||||
from .oauth_state import OAuthState
|
||||
from .organization import Organization
|
||||
|
||||
# Import models
|
||||
@@ -16,6 +21,9 @@ from .user_session import UserSession
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"OAuthAccount",
|
||||
"OAuthClient",
|
||||
"OAuthState",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
"TimestampMixin",
|
||||
|
||||
55
backend/app/models/oauth_account.py
Normal file
55
backend/app/models/oauth_account.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""OAuth account model for linking external OAuth providers to users."""
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthAccount(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Links OAuth provider accounts to users.
|
||||
|
||||
Supports multiple OAuth providers per user (e.g., user can have both
|
||||
Google and GitHub connected). Each provider account is uniquely identified
|
||||
by (provider, provider_user_id).
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_accounts"
|
||||
|
||||
# Link to user
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# OAuth provider identification
|
||||
provider = Column(
|
||||
String(50), nullable=False, index=True
|
||||
) # google, github, microsoft
|
||||
provider_user_id = Column(String(255), nullable=False) # Provider's unique user ID
|
||||
provider_email = Column(
|
||||
String(255), nullable=True, index=True
|
||||
) # Email from provider (for reference)
|
||||
|
||||
# Optional: store provider tokens for API access
|
||||
# These should be encrypted at rest in production
|
||||
access_token_encrypted = Column(String(2048), nullable=True)
|
||||
refresh_token_encrypted = Column(String(2048), nullable=True)
|
||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationship
|
||||
user = relationship("User", back_populates="oauth_accounts")
|
||||
|
||||
__table_args__ = (
|
||||
# Each provider account can only be linked to one user
|
||||
UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||
# Index for finding all OAuth accounts for a user + provider
|
||||
Index("ix_oauth_accounts_user_provider", "user_id", "provider"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthAccount {self.provider}:{self.provider_user_id}>"
|
||||
67
backend/app/models/oauth_client.py
Normal file
67
backend/app/models/oauth_client.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""OAuth client model for OAuth provider mode (MCP clients)."""
|
||||
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthClient(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Registered OAuth clients (for OAuth provider mode).
|
||||
|
||||
This model stores third-party applications that can authenticate
|
||||
against this API using OAuth 2.0. Used for MCP (Model Context Protocol)
|
||||
client authentication and API access.
|
||||
|
||||
NOTE: This is a skeleton implementation. The full OAuth provider
|
||||
functionality (authorization endpoint, token endpoint, etc.) can be
|
||||
expanded when needed.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_clients"
|
||||
|
||||
# Client credentials
|
||||
client_id = Column(String(64), unique=True, nullable=False, index=True)
|
||||
client_secret_hash = Column(
|
||||
String(255), nullable=True
|
||||
) # NULL for public clients (PKCE)
|
||||
|
||||
# Client metadata
|
||||
client_name = Column(String(255), nullable=False)
|
||||
client_description = Column(String(1000), nullable=True)
|
||||
|
||||
# Client type: "public" (SPA, mobile) or "confidential" (server-side)
|
||||
client_type = Column(String(20), nullable=False, default="public")
|
||||
|
||||
# Allowed redirect URIs (JSON array)
|
||||
redirect_uris = Column(JSONB, nullable=False, default=list)
|
||||
|
||||
# Allowed scopes (JSON array of scope names)
|
||||
allowed_scopes = Column(JSONB, nullable=False, default=list)
|
||||
|
||||
# Token lifetimes (in seconds)
|
||||
access_token_lifetime = Column(String(10), nullable=False, default="3600") # 1 hour
|
||||
refresh_token_lifetime = Column(
|
||||
String(10), nullable=False, default="604800"
|
||||
) # 7 days
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Optional: owner user (for user-registered applications)
|
||||
owner_user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# MCP-specific: URL of the MCP server this client represents
|
||||
mcp_server_url = Column(String(2048), nullable=True)
|
||||
|
||||
# Relationship
|
||||
owner = relationship("User", backref="owned_oauth_clients")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthClient {self.client_name} ({self.client_id[:8]}...)>"
|
||||
45
backend/app/models/oauth_state.py
Normal file
45
backend/app/models/oauth_state.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""OAuth state model for CSRF protection during OAuth flows."""
|
||||
|
||||
from sqlalchemy import Column, DateTime, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthState(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Temporary storage for OAuth state parameters.
|
||||
|
||||
Prevents CSRF attacks during OAuth flows by storing a random state
|
||||
value that must match on callback. Also stores PKCE code_verifier
|
||||
for the Authorization Code flow with PKCE.
|
||||
|
||||
These records are short-lived (10 minutes by default) and should
|
||||
be deleted after use or expiration.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
# Random state parameter (CSRF protection)
|
||||
state = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# PKCE code_verifier (used to generate code_challenge)
|
||||
code_verifier = Column(String(128), nullable=True)
|
||||
|
||||
# OIDC nonce for ID token replay protection
|
||||
nonce = Column(String(255), nullable=True)
|
||||
|
||||
# OAuth provider (google, github, etc.)
|
||||
provider = Column(String(50), nullable=False)
|
||||
|
||||
# Original redirect URI (for callback validation)
|
||||
redirect_uri = Column(String(500), nullable=True)
|
||||
|
||||
# User ID if this is an account linking flow (user is already logged in)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=True)
|
||||
|
||||
# Expiration time
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthState {self.state[:8]}... ({self.provider})>"
|
||||
@@ -9,7 +9,8 @@ class User(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = "users"
|
||||
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
# Nullable to support OAuth-only users who never set a password
|
||||
password_hash = Column(String(255), nullable=True)
|
||||
first_name = Column(String(100), nullable=False, default="user")
|
||||
last_name = Column(String(100), nullable=True)
|
||||
phone_number = Column(String(20))
|
||||
@@ -23,6 +24,19 @@ class User(Base, UUIDMixin, TimestampMixin):
|
||||
user_organizations = relationship(
|
||||
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oauth_accounts = relationship(
|
||||
"OAuthAccount", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def has_password(self) -> bool:
|
||||
"""Check if user can login with password (not OAuth-only)."""
|
||||
return self.password_hash is not None
|
||||
|
||||
@property
|
||||
def can_remove_oauth(self) -> bool:
|
||||
"""Check if user can safely remove an OAuth account link."""
|
||||
return self.has_password or len(self.oauth_accounts) > 1
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
|
||||
313
backend/app/schemas/oauth.py
Normal file
313
backend/app/schemas/oauth.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
Pydantic schemas for OAuth authentication.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Provider Info (for frontend to display available providers)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthProviderInfo(BaseModel):
|
||||
"""Information about an available OAuth provider."""
|
||||
|
||||
provider: str = Field(..., description="Provider identifier (google, github)")
|
||||
name: str = Field(..., description="Human-readable provider name")
|
||||
icon: str | None = Field(None, description="Icon identifier for frontend")
|
||||
|
||||
|
||||
class OAuthProvidersResponse(BaseModel):
|
||||
"""Response containing list of enabled OAuth providers."""
|
||||
|
||||
enabled: bool = Field(..., description="Whether OAuth is globally enabled")
|
||||
providers: list[OAuthProviderInfo] = Field(
|
||||
default_factory=list, description="List of enabled providers"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"enabled": True,
|
||||
"providers": [
|
||||
{"provider": "google", "name": "Google", "icon": "google"},
|
||||
{"provider": "github", "name": "GitHub", "icon": "github"},
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Account (linked provider accounts)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthAccountBase(BaseModel):
|
||||
"""Base schema for OAuth accounts."""
|
||||
|
||||
provider: str = Field(..., max_length=50, description="OAuth provider name")
|
||||
provider_email: str | None = Field(
|
||||
None, max_length=255, description="Email from OAuth provider"
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccountCreate(OAuthAccountBase):
|
||||
"""Schema for creating an OAuth account link (internal use)."""
|
||||
|
||||
user_id: UUID
|
||||
provider_user_id: str = Field(..., max_length=255)
|
||||
access_token_encrypted: str | None = None
|
||||
refresh_token_encrypted: str | None = None
|
||||
token_expires_at: datetime | None = None
|
||||
|
||||
|
||||
class OAuthAccountResponse(OAuthAccountBase):
|
||||
"""Schema for OAuth account response to clients."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"provider": "google",
|
||||
"provider_email": "user@gmail.com",
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccountsListResponse(BaseModel):
|
||||
"""Response containing list of linked OAuth accounts."""
|
||||
|
||||
accounts: list[OAuthAccountResponse]
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"accounts": [
|
||||
{
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"provider": "google",
|
||||
"provider_email": "user@gmail.com",
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Flow (authorization, callback, etc.)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthAuthorizeRequest(BaseModel):
|
||||
"""Request parameters for OAuth authorization."""
|
||||
|
||||
provider: str = Field(..., description="OAuth provider (google, github)")
|
||||
redirect_uri: str | None = Field(
|
||||
None, description="Frontend callback URL after OAuth"
|
||||
)
|
||||
mode: str = Field(
|
||||
default="login",
|
||||
description="OAuth mode: login, register, or link",
|
||||
pattern="^(login|register|link)$",
|
||||
)
|
||||
|
||||
|
||||
class OAuthCallbackRequest(BaseModel):
|
||||
"""Request parameters for OAuth callback."""
|
||||
|
||||
code: str = Field(..., description="Authorization code from provider")
|
||||
state: str = Field(..., description="State parameter for CSRF protection")
|
||||
|
||||
|
||||
class OAuthCallbackResponse(BaseModel):
|
||||
"""Response after successful OAuth authentication."""
|
||||
|
||||
access_token: str = Field(..., description="JWT access token")
|
||||
refresh_token: str = Field(..., description="JWT refresh token")
|
||||
token_type: str = Field(default="bearer")
|
||||
expires_in: int = Field(..., description="Token expiration in seconds")
|
||||
is_new_user: bool = Field(
|
||||
default=False, description="Whether a new user was created"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 900,
|
||||
"is_new_user": False,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthUnlinkResponse(BaseModel):
|
||||
"""Response after unlinking an OAuth account."""
|
||||
|
||||
success: bool = Field(..., description="Whether the unlink was successful")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {"success": True, "message": "Google account unlinked"}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth State (CSRF protection - internal use)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthStateCreate(BaseModel):
|
||||
"""Schema for creating OAuth state (internal use)."""
|
||||
|
||||
state: str = Field(..., max_length=255)
|
||||
code_verifier: str | None = Field(None, max_length=128)
|
||||
nonce: str | None = Field(None, max_length=255)
|
||||
provider: str = Field(..., max_length=50)
|
||||
redirect_uri: str | None = Field(None, max_length=500)
|
||||
user_id: UUID | None = None
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Client (Provider Mode - MCP clients)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthClientBase(BaseModel):
|
||||
"""Base schema for OAuth clients."""
|
||||
|
||||
client_name: str = Field(..., max_length=255, description="Client application name")
|
||||
client_description: str | None = Field(
|
||||
None, max_length=1000, description="Client description"
|
||||
)
|
||||
redirect_uris: list[str] = Field(
|
||||
default_factory=list, description="Allowed redirect URIs"
|
||||
)
|
||||
allowed_scopes: list[str] = Field(
|
||||
default_factory=list, description="Allowed OAuth scopes"
|
||||
)
|
||||
|
||||
|
||||
class OAuthClientCreate(OAuthClientBase):
|
||||
"""Schema for creating an OAuth client."""
|
||||
|
||||
client_type: str = Field(
|
||||
default="public",
|
||||
description="Client type: public or confidential",
|
||||
pattern="^(public|confidential)$",
|
||||
)
|
||||
|
||||
|
||||
class OAuthClientResponse(OAuthClientBase):
|
||||
"""Schema for OAuth client response."""
|
||||
|
||||
id: UUID
|
||||
client_id: str = Field(..., description="OAuth client ID")
|
||||
client_type: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"client_id": "abc123def456",
|
||||
"client_name": "My MCP App",
|
||||
"client_description": "My application that uses MCP",
|
||||
"client_type": "public",
|
||||
"redirect_uris": ["http://localhost:3000/callback"],
|
||||
"allowed_scopes": ["read:users", "write:users"],
|
||||
"is_active": True,
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class OAuthClientWithSecret(OAuthClientResponse):
|
||||
"""Schema for OAuth client response including secret (only shown once)."""
|
||||
|
||||
client_secret: str | None = Field(
|
||||
None, description="Client secret (only shown once for confidential clients)"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"client_id": "abc123def456",
|
||||
"client_secret": "secret_xyz789",
|
||||
"client_name": "My MCP App",
|
||||
"client_type": "confidential",
|
||||
"redirect_uris": ["http://localhost:3000/callback"],
|
||||
"allowed_scopes": ["read:users"],
|
||||
"is_active": True,
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Provider Discovery (RFC 8414 - skeleton)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthServerMetadata(BaseModel):
|
||||
"""OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
|
||||
|
||||
issuer: str = Field(..., description="Authorization server issuer URL")
|
||||
authorization_endpoint: str = Field(..., description="Authorization endpoint URL")
|
||||
token_endpoint: str = Field(..., description="Token endpoint URL")
|
||||
registration_endpoint: str | None = Field(
|
||||
None, description="Dynamic client registration endpoint"
|
||||
)
|
||||
revocation_endpoint: str | None = Field(
|
||||
None, description="Token revocation endpoint"
|
||||
)
|
||||
scopes_supported: list[str] = Field(
|
||||
default_factory=list, description="Supported scopes"
|
||||
)
|
||||
response_types_supported: list[str] = Field(
|
||||
default_factory=lambda: ["code"], description="Supported response types"
|
||||
)
|
||||
grant_types_supported: list[str] = Field(
|
||||
default_factory=lambda: ["authorization_code", "refresh_token"],
|
||||
description="Supported grant types",
|
||||
)
|
||||
code_challenge_methods_supported: list[str] = Field(
|
||||
default_factory=lambda: ["S256"], description="Supported PKCE methods"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"issuer": "https://api.example.com",
|
||||
"authorization_endpoint": "https://api.example.com/oauth/authorize",
|
||||
"token_endpoint": "https://api.example.com/oauth/token",
|
||||
"scopes_supported": ["openid", "profile", "email", "read:users"],
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
# app/services/__init__.py
|
||||
from .auth_service import AuthService
|
||||
from .oauth_service import OAuthService
|
||||
|
||||
__all__ = ["AuthService", "OAuthService"]
|
||||
|
||||
598
backend/app/services/oauth_service.py
Normal file
598
backend/app/services/oauth_service.py
Normal file
@@ -0,0 +1,598 @@
|
||||
"""
|
||||
OAuth Service for handling social authentication flows.
|
||||
|
||||
Supports:
|
||||
- Google OAuth (OpenID Connect)
|
||||
- GitHub OAuth
|
||||
|
||||
Features:
|
||||
- PKCE support for public clients
|
||||
- State parameter for CSRF protection
|
||||
- Auto-linking by email (configurable)
|
||||
- Account linking for existing users
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TypedDict, cast
|
||||
from uuid import UUID
|
||||
|
||||
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import create_access_token, create_refresh_token
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.crud import oauth_account, oauth_state
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import (
|
||||
OAuthAccountCreate,
|
||||
OAuthCallbackResponse,
|
||||
OAuthProviderInfo,
|
||||
OAuthProvidersResponse,
|
||||
OAuthStateCreate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthProviderConfig(TypedDict, total=False):
|
||||
"""Type definition for OAuth provider configuration."""
|
||||
|
||||
name: str
|
||||
icon: str
|
||||
authorize_url: str
|
||||
token_url: str
|
||||
userinfo_url: str
|
||||
email_url: str # Optional, GitHub-only
|
||||
scopes: list[str]
|
||||
supports_pkce: bool
|
||||
|
||||
|
||||
# Provider configurations
|
||||
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
|
||||
"google": {
|
||||
"name": "Google",
|
||||
"icon": "google",
|
||||
"authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"token_url": "https://oauth2.googleapis.com/token",
|
||||
"userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
|
||||
"scopes": ["openid", "email", "profile"],
|
||||
"supports_pkce": True,
|
||||
},
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"icon": "github",
|
||||
"authorize_url": "https://github.com/login/oauth/authorize",
|
||||
"token_url": "https://github.com/login/oauth/access_token",
|
||||
"userinfo_url": "https://api.github.com/user",
|
||||
"email_url": "https://api.github.com/user/emails",
|
||||
"scopes": ["read:user", "user:email"],
|
||||
"supports_pkce": False, # GitHub doesn't support PKCE
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OAuthService:
|
||||
"""Service for handling OAuth authentication flows."""
|
||||
|
||||
@staticmethod
|
||||
def get_enabled_providers() -> OAuthProvidersResponse:
|
||||
"""
|
||||
Get list of enabled OAuth providers.
|
||||
|
||||
Returns:
|
||||
OAuthProvidersResponse with enabled providers
|
||||
"""
|
||||
providers = []
|
||||
|
||||
for provider_id in settings.enabled_oauth_providers:
|
||||
if provider_id in OAUTH_PROVIDERS:
|
||||
config = OAUTH_PROVIDERS[provider_id]
|
||||
providers.append(
|
||||
OAuthProviderInfo(
|
||||
provider=provider_id,
|
||||
name=config["name"],
|
||||
icon=config["icon"],
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthProvidersResponse(
|
||||
enabled=settings.OAUTH_ENABLED and len(providers) > 0,
|
||||
providers=providers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_credentials(provider: str) -> tuple[str, str]:
|
||||
"""Get client ID and secret for a provider."""
|
||||
if provider == "google":
|
||||
client_id = settings.OAUTH_GOOGLE_CLIENT_ID
|
||||
client_secret = settings.OAUTH_GOOGLE_CLIENT_SECRET
|
||||
elif provider == "github":
|
||||
client_id = settings.OAUTH_GITHUB_CLIENT_ID
|
||||
client_secret = settings.OAUTH_GITHUB_CLIENT_SECRET
|
||||
else:
|
||||
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||
|
||||
if not client_id or not client_secret:
|
||||
raise AuthenticationError(f"OAuth provider {provider} is not configured")
|
||||
|
||||
return client_id, client_secret
|
||||
|
||||
@staticmethod
|
||||
async def create_authorization_url(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
redirect_uri: str,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Create OAuth authorization URL with state and optional PKCE.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
provider: OAuth provider (google, github)
|
||||
redirect_uri: Callback URL after OAuth
|
||||
user_id: User ID if linking account (user is logged in)
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If provider is not configured
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise AuthenticationError("OAuth is not enabled")
|
||||
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||
|
||||
if provider not in settings.enabled_oauth_providers:
|
||||
raise AuthenticationError(f"OAuth provider {provider} is not enabled")
|
||||
|
||||
config = OAUTH_PROVIDERS[provider]
|
||||
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
||||
|
||||
# Generate state for CSRF protection
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
# Generate PKCE code verifier and challenge if supported
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if config.get("supports_pkce"):
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
# Create code_challenge using S256 method
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
code_challenge_bytes = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(code_challenge_bytes).decode().rstrip("=")
|
||||
)
|
||||
|
||||
# Generate nonce for OIDC (Google)
|
||||
nonce = secrets.token_urlsafe(32) if provider == "google" else None
|
||||
|
||||
# Store state in database
|
||||
from uuid import UUID
|
||||
|
||||
state_data = OAuthStateCreate(
|
||||
state=state,
|
||||
code_verifier=code_verifier,
|
||||
nonce=nonce,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
user_id=UUID(user_id) if user_id else None,
|
||||
expires_at=datetime.now(UTC)
|
||||
+ timedelta(minutes=settings.OAUTH_STATE_EXPIRE_MINUTES),
|
||||
)
|
||||
await oauth_state.create_state(db, obj_in=state_data)
|
||||
|
||||
# Build authorization URL
|
||||
async with AsyncOAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
) as client:
|
||||
# Prepare authorization params
|
||||
auth_params = {
|
||||
"state": state,
|
||||
"scope": " ".join(config["scopes"]),
|
||||
}
|
||||
|
||||
if code_challenge:
|
||||
auth_params["code_challenge"] = code_challenge
|
||||
auth_params["code_challenge_method"] = "S256"
|
||||
|
||||
if nonce:
|
||||
auth_params["nonce"] = nonce
|
||||
|
||||
url, _ = client.create_authorization_url(
|
||||
config["authorize_url"],
|
||||
**auth_params,
|
||||
)
|
||||
|
||||
logger.info(f"OAuth authorization URL created for {provider}")
|
||||
return url, state
|
||||
|
||||
@staticmethod
|
||||
async def handle_callback(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthCallbackResponse:
|
||||
"""
|
||||
Handle OAuth callback and authenticate/create user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
code: Authorization code from provider
|
||||
state: State parameter for CSRF verification
|
||||
redirect_uri: Callback URL (must match authorization request)
|
||||
|
||||
Returns:
|
||||
OAuthCallbackResponse with tokens
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If authentication fails
|
||||
"""
|
||||
# Validate and consume state
|
||||
state_record = await oauth_state.get_and_consume_state(db, state=state)
|
||||
if not state_record:
|
||||
raise AuthenticationError("Invalid or expired OAuth state")
|
||||
|
||||
# Extract provider from state record (str for type safety)
|
||||
provider: str = str(state_record.provider)
|
||||
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||
|
||||
config = OAUTH_PROVIDERS[provider]
|
||||
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
||||
|
||||
# Exchange code for tokens
|
||||
async with AsyncOAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
) as client:
|
||||
try:
|
||||
# Prepare token request params
|
||||
token_params: dict[str, str] = {"code": code}
|
||||
|
||||
if state_record.code_verifier:
|
||||
token_params["code_verifier"] = str(state_record.code_verifier)
|
||||
|
||||
token = await client.fetch_token(
|
||||
config["token_url"],
|
||||
**token_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth token exchange failed: {e!s}")
|
||||
raise AuthenticationError("Failed to exchange authorization code")
|
||||
|
||||
# Get user info from provider
|
||||
try:
|
||||
access_token = token.get("access_token")
|
||||
if not access_token:
|
||||
raise AuthenticationError("No access token received")
|
||||
|
||||
user_info = await OAuthService._get_user_info(
|
||||
client, provider, config, access_token
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user info: {e!s}")
|
||||
raise AuthenticationError(
|
||||
"Failed to get user information from provider"
|
||||
)
|
||||
|
||||
# Process user info and create/link account
|
||||
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
|
||||
# Email can be None if user didn't grant email permission
|
||||
email_raw = user_info.get("email")
|
||||
provider_email: str | None = str(email_raw) if email_raw else None
|
||||
|
||||
if not provider_user_id:
|
||||
raise AuthenticationError("Provider did not return user ID")
|
||||
|
||||
# Check if this OAuth account already exists
|
||||
existing_oauth = await oauth_account.get_by_provider_id(
|
||||
db, provider=provider, provider_user_id=provider_user_id
|
||||
)
|
||||
|
||||
is_new_user = False
|
||||
|
||||
if existing_oauth:
|
||||
# Existing OAuth account - login
|
||||
user = existing_oauth.user
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
# Update tokens if stored
|
||||
if token.get("access_token"):
|
||||
await oauth_account.update_tokens(
|
||||
db,
|
||||
account=existing_oauth,
|
||||
access_token_encrypted=token.get("access_token"), # TODO: encrypt
|
||||
refresh_token_encrypted=token.get("refresh_token"), # TODO: encrypt
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600)),
|
||||
)
|
||||
|
||||
logger.info(f"OAuth login successful for {user.email} via {provider}")
|
||||
|
||||
elif state_record.user_id:
|
||||
# Account linking flow (user is already logged in)
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == state_record.user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise AuthenticationError("User not found for account linking")
|
||||
|
||||
# Check if user already has this provider linked
|
||||
user_id = cast(UUID, user.id)
|
||||
existing_provider = await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
if existing_provider:
|
||||
raise AuthenticationError(
|
||||
f"You already have a {provider} account linked"
|
||||
)
|
||||
|
||||
# Create OAuth account link
|
||||
oauth_create = OAuthAccountCreate(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token_encrypted=token.get("access_token"), # TODO: encrypt
|
||||
refresh_token_encrypted=token.get("refresh_token"), # TODO: encrypt
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
logger.info(f"OAuth account linked: {provider} -> {user.email}")
|
||||
|
||||
else:
|
||||
# New OAuth login - check for existing user by email
|
||||
user = None
|
||||
|
||||
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == provider_email)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user:
|
||||
# Auto-link to existing user
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
# Check if user already has this provider linked
|
||||
user_id = cast(UUID, user.id)
|
||||
existing_provider = await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
if existing_provider:
|
||||
# This shouldn't happen if we got here, but safety check
|
||||
logger.warning(
|
||||
f"OAuth account already linked (race condition?): {provider} -> {user.email}"
|
||||
)
|
||||
else:
|
||||
# Create OAuth account link
|
||||
oauth_create = OAuthAccountCreate(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token_encrypted=token.get("access_token"),
|
||||
refresh_token_encrypted=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
logger.info(f"OAuth auto-linked by email: {provider} -> {user.email}")
|
||||
|
||||
else:
|
||||
# Create new user
|
||||
if not provider_email:
|
||||
raise AuthenticationError(
|
||||
f"Email is required for registration. "
|
||||
f"Please grant email permission to {provider}."
|
||||
)
|
||||
|
||||
user = await OAuthService._create_oauth_user(
|
||||
db,
|
||||
email=provider_email,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
user_info=user_info,
|
||||
token=token,
|
||||
)
|
||||
is_new_user = True
|
||||
|
||||
logger.info(f"New user created via OAuth: {user.email} ({provider})")
|
||||
|
||||
# Generate JWT tokens
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
}
|
||||
|
||||
access_token_jwt = create_access_token(subject=str(user.id), claims=claims)
|
||||
refresh_token_jwt = create_refresh_token(subject=str(user.id))
|
||||
|
||||
return OAuthCallbackResponse(
|
||||
access_token=access_token_jwt,
|
||||
refresh_token=refresh_token_jwt,
|
||||
token_type="bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
is_new_user=is_new_user,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _get_user_info(
|
||||
client: AsyncOAuth2Client,
|
||||
provider: str,
|
||||
config: OAuthProviderConfig,
|
||||
access_token: str,
|
||||
) -> dict[str, object]:
|
||||
"""Get user info from OAuth provider."""
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
if provider == "github":
|
||||
# GitHub returns JSON with Accept header
|
||||
headers["Accept"] = "application/vnd.github+json"
|
||||
|
||||
resp = await client.get(config["userinfo_url"], headers=headers)
|
||||
resp.raise_for_status()
|
||||
user_info = resp.json()
|
||||
|
||||
# GitHub requires separate request for email
|
||||
if provider == "github" and not user_info.get("email"):
|
||||
email_resp = await client.get(
|
||||
config["email_url"],
|
||||
headers=headers,
|
||||
)
|
||||
email_resp.raise_for_status()
|
||||
emails = email_resp.json()
|
||||
|
||||
# Find primary verified email
|
||||
for email_data in emails:
|
||||
if email_data.get("primary") and email_data.get("verified"):
|
||||
user_info["email"] = email_data["email"]
|
||||
break
|
||||
|
||||
return user_info
|
||||
|
||||
@staticmethod
|
||||
async def _create_oauth_user(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
email: str,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
user_info: dict,
|
||||
token: dict,
|
||||
) -> User:
|
||||
"""Create a new user from OAuth provider data."""
|
||||
# Extract name from user_info
|
||||
first_name = "User"
|
||||
last_name = None
|
||||
|
||||
if provider == "google":
|
||||
first_name = user_info.get("given_name") or user_info.get("name", "User")
|
||||
last_name = user_info.get("family_name")
|
||||
elif provider == "github":
|
||||
# GitHub has full name, try to split
|
||||
name = user_info.get("name") or user_info.get("login", "User")
|
||||
parts = name.split(" ", 1)
|
||||
first_name = parts[0]
|
||||
last_name = parts[1] if len(parts) > 1 else None
|
||||
|
||||
# Create user (no password for OAuth-only users)
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=None, # OAuth-only user
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # Get user.id
|
||||
|
||||
# Create OAuth account link
|
||||
user_id = cast(UUID, user.id)
|
||||
oauth_create = OAuthAccountCreate(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=email,
|
||||
access_token_encrypted=token.get("access_token"), # TODO: encrypt
|
||||
refresh_token_encrypted=token.get("refresh_token"), # TODO: encrypt
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def unlink_provider(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user: User,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Unlink an OAuth provider from a user account.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user: User to unlink from
|
||||
provider: Provider to unlink
|
||||
|
||||
Returns:
|
||||
True if unlinked successfully
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If unlinking would leave user without login method
|
||||
"""
|
||||
# Check if user can safely remove this OAuth account
|
||||
# Note: We query directly instead of using user.can_remove_oauth property
|
||||
# because the property uses lazy loading which doesn't work in async context
|
||||
user_id = cast(UUID, user.id)
|
||||
has_password = user.password_hash is not None
|
||||
oauth_accounts = await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||
can_remove = has_password or len(oauth_accounts) > 1
|
||||
|
||||
if not can_remove:
|
||||
raise AuthenticationError(
|
||||
"Cannot unlink OAuth account. You must have either a password set "
|
||||
"or at least one other OAuth provider linked."
|
||||
)
|
||||
|
||||
deleted = await oauth_account.delete_account(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
|
||||
if not deleted:
|
||||
raise AuthenticationError(f"No {provider} account found to unlink")
|
||||
|
||||
logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_states(db: AsyncSession) -> int:
|
||||
"""
|
||||
Clean up expired OAuth states.
|
||||
|
||||
Should be called periodically (e.g., by a background task).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of states cleaned up
|
||||
"""
|
||||
return await oauth_state.cleanup_expired(db)
|
||||
Reference in New Issue
Block a user