Files
fast-next-template/backend/app/crud/oauth.py
Felipe Cardoso 16ee4e0cb3 Initial implementation of OAuth models, endpoints, and migrations
- Added models for `OAuthClient`, `OAuthState`, and `OAuthAccount`.
- Created Pydantic schemas to support OAuth flows, client management, and linked accounts.
- Implemented skeleton endpoints for OAuth Provider mode: authorization, token, and revocation.
- Updated router imports to include new `/oauth` and `/oauth/provider` routes.
- Added Alembic migration script to create OAuth-related database tables.
- Enhanced `users` table to allow OAuth-only accounts by making `password_hash` nullable.
2025-11-25 00:37:23 +01:00

654 lines
20 KiB
Python

"""
Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns.
Provides operations for:
- OAuthAccount: Managing linked OAuth provider accounts
- OAuthState: CSRF protection state during OAuth flows
- OAuthClient: Registered OAuth clients (provider mode skeleton)
"""
import logging
import secrets
from datetime import UTC, datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.oauth_account import OAuthAccount
from app.models.oauth_client import OAuthClient
from app.models.oauth_state import OAuthState
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
logger = logging.getLogger(__name__)
# ============================================================================
# OAuth Account CRUD
# ============================================================================
class EmptySchema(BaseModel):
"""Placeholder schema for CRUD operations that don't need update schemas."""
class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]):
"""CRUD operations for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and provider user ID.
Args:
db: Database session
provider: OAuth provider name (google, github)
provider_user_id: User ID from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and email.
Used for auto-linking existing accounts by email.
Args:
db: Database session
provider: OAuth provider name
email: Email address from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(
f"Error getting OAuth account for {provider} email {email}: {e!s}"
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""
Get all OAuth accounts linked to a user.
Args:
db: Database session
user_id: User ID
Returns:
List of OAuthAccount objects
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""
Get a specific OAuth account for a user and provider.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
OAuthAccount if found, None otherwise
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""
Create a new OAuth account link.
Args:
db: Database session
obj_in: OAuth account creation data
Returns:
Created OAuthAccount
Raises:
ValueError: If account already exists or creation fails
"""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token_encrypted=obj_in.access_token_encrypted,
refresh_token_encrypted=obj_in.refresh_token_encrypted,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
)
raise ValueError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error(f"Integrity error creating OAuth account: {error_msg}")
raise ValueError(f"Failed to create OAuth account: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""
Delete an OAuth account link.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
True if deleted, False if not found
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
f"OAuth account deleted: {provider} unlinked from user {user_id}"
)
else:
logger.warning(
f"OAuth account not found for deletion: {provider} for user {user_id}"
)
return deleted
except Exception as e:
await db.rollback()
logger.error(
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token_encrypted: str | None = None,
refresh_token_encrypted: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""
Update OAuth tokens for an account.
Args:
db: Database session
account: OAuthAccount to update
access_token_encrypted: New encrypted access token
refresh_token_encrypted: New encrypted refresh token
token_expires_at: New token expiration time
Returns:
Updated OAuthAccount
"""
try:
if access_token_encrypted is not None:
account.access_token_encrypted = access_token_encrypted
if refresh_token_encrypted is not None:
account.refresh_token_encrypted = refresh_token_encrypted
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e:
await db.rollback()
logger.error(f"Error updating OAuth tokens: {e!s}")
raise
# ============================================================================
# OAuth State CRUD
# ============================================================================
class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]):
"""CRUD operations for OAuth state (CSRF protection)."""
async def create_state(
self, db: AsyncSession, *, obj_in: OAuthStateCreate
) -> OAuthState:
"""
Create a new OAuth state for CSRF protection.
Args:
db: Database session
obj_in: OAuth state creation data
Returns:
Created OAuthState
"""
try:
db_obj = OAuthState(
state=obj_in.state,
code_verifier=obj_in.code_verifier,
nonce=obj_in.nonce,
provider=obj_in.provider,
redirect_uri=obj_in.redirect_uri,
user_id=obj_in.user_id,
expires_at=obj_in.expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.debug(f"OAuth state created for {obj_in.provider}")
return db_obj
except IntegrityError as e:
await db.rollback()
# State collision (extremely rare with cryptographic random)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"OAuth state collision: {error_msg}")
raise ValueError("Failed to create OAuth state, please retry")
except Exception as e:
await db.rollback()
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
raise
async def get_and_consume_state(
self, db: AsyncSession, *, state: str
) -> OAuthState | None:
"""
Get and delete OAuth state (consume it).
This ensures each state can only be used once (replay protection).
Args:
db: Database session
state: State string to look up
Returns:
OAuthState if found and valid, None otherwise
"""
try:
# Get the state
result = await db.execute(
select(OAuthState).where(OAuthState.state == state)
)
db_obj = result.scalar_one_or_none()
if db_obj is None:
logger.warning(f"OAuth state not found: {state[:8]}...")
return None
# Check expiration
# Handle both timezone-aware and timezone-naive datetimes
now = datetime.now(UTC)
expires_at = db_obj.expires_at
if expires_at.tzinfo is None:
# SQLite returns naive datetimes, assume UTC
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.warning(f"OAuth state expired: {state[:8]}...")
await db.delete(db_obj)
await db.commit()
return None
# Delete it (consume)
await db.delete(db_obj)
await db.commit()
logger.debug(f"OAuth state consumed: {state[:8]}...")
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error consuming OAuth state: {e!s}")
raise
async def cleanup_expired(self, db: AsyncSession) -> int:
"""
Clean up expired OAuth states.
Should be called periodically to remove stale states.
Args:
db: Database session
Returns:
Number of states deleted
"""
try:
now = datetime.now(UTC)
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired OAuth states")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
raise
# ============================================================================
# OAuth Client CRUD (Provider Mode - Skeleton)
# ============================================================================
class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
"""
CRUD operations for OAuth clients (provider mode).
This is a skeleton implementation for MCP client registration.
Full implementation can be expanded when needed.
"""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Get OAuth client by client_id.
Args:
db: Database session
client_id: OAuth client ID
Returns:
OAuthClient if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""
Create a new OAuth client.
Args:
db: Database session
obj_in: OAuth client creation data
owner_user_id: Optional owner user ID
Returns:
Tuple of (created OAuthClient, client_secret or None for public clients)
"""
try:
# Generate client_id
client_id = secrets.token_urlsafe(32)
# Generate client_secret for confidential clients
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
# In production, use proper password hashing (bcrypt)
# For now, we store a hash placeholder
import hashlib
client_secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
)
return db_obj, client_secret
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Error creating OAuth client: {error_msg}")
raise ValueError(f"Failed to create OAuth client: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Deactivate an OAuth client.
Args:
db: Database session
client_id: OAuth client ID
Returns:
Deactivated OAuthClient if found, None otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info(f"OAuth client deactivated: {client.client_name}")
return client
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""
Validate that a redirect URI is allowed for a client.
Args:
db: Database session
client_id: OAuth client ID
redirect_uri: Redirect URI to validate
Returns:
True if valid, False otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e:
logger.error(f"Error validating redirect URI: {e!s}")
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""
Verify client credentials.
Args:
db: Database session
client_id: OAuth client ID
client_secret: Client secret to verify
Returns:
True if valid, False otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
# Verify secret
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
# Cast to str for type safety with compare_digest
stored_hash: str = str(client.client_secret_hash)
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e:
logger.error(f"Error verifying client secret: {e!s}")
return False
# ============================================================================
# Singleton instances
# ============================================================================
oauth_account = CRUDOAuthAccount(OAuthAccount)
oauth_state = CRUDOAuthState(OAuthState)
oauth_client = CRUDOAuthClient(OAuthClient)