Initial implementation of OAuth models, endpoints, and migrations

- Added models for `OAuthClient`, `OAuthState`, and `OAuthAccount`.
- Created Pydantic schemas to support OAuth flows, client management, and linked accounts.
- Implemented skeleton endpoints for OAuth Provider mode: authorization, token, and revocation.
- Updated router imports to include new `/oauth` and `/oauth/provider` routes.
- Added Alembic migration script to create OAuth-related database tables.
- Enhanced `users` table to allow OAuth-only accounts by making `password_hash` nullable.
This commit is contained in:
Felipe Cardoso
2025-11-25 00:37:23 +01:00
parent e6792c2d6c
commit 16ee4e0cb3
23 changed files with 4109 additions and 13 deletions

View File

@@ -0,0 +1,144 @@
"""add oauth models
Revision ID: d5a7b2c9e1f3
Revises: c8e9f3a2d1b4
Create Date: 2025-11-24 20:00:00.000000
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "d5a7b2c9e1f3"
down_revision: str | None = "c8e9f3a2d1b4"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# 1. Make password_hash nullable on users table (for OAuth-only users)
op.alter_column(
"users",
"password_hash",
existing_type=sa.String(length=255),
nullable=True,
)
# 2. Create oauth_accounts table (links OAuth providers to users)
op.create_table(
"oauth_accounts",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("provider", sa.String(length=50), nullable=False),
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
sa.Column("provider_email", sa.String(length=255), nullable=True),
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
name="fk_oauth_accounts_user_id",
ondelete="CASCADE",
),
sa.UniqueConstraint(
"provider", "provider_user_id", name="uq_oauth_provider_user"
),
)
# Create indexes for oauth_accounts
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
op.create_index("ix_oauth_accounts_provider", "oauth_accounts", ["provider"])
op.create_index(
"ix_oauth_accounts_provider_email", "oauth_accounts", ["provider_email"]
)
op.create_index(
"ix_oauth_accounts_user_provider", "oauth_accounts", ["user_id", "provider"]
)
# 3. Create oauth_states table (CSRF protection during OAuth flow)
op.create_table(
"oauth_states",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("state", sa.String(length=255), nullable=False),
sa.Column("code_verifier", sa.String(length=128), nullable=True),
sa.Column("nonce", sa.String(length=255), nullable=True),
sa.Column("provider", sa.String(length=50), nullable=False),
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create indexes for oauth_states
op.create_index("ix_oauth_states_state", "oauth_states", ["state"], unique=True)
op.create_index("ix_oauth_states_expires_at", "oauth_states", ["expires_at"])
# 4. Create oauth_clients table (OAuth provider mode - skeleton for MCP)
op.create_table(
"oauth_clients",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
sa.Column("client_name", sa.String(length=255), nullable=False),
sa.Column("client_description", sa.String(length=1000), nullable=True),
sa.Column("client_type", sa.String(length=20), nullable=False),
sa.Column("redirect_uris", postgresql.JSONB(), nullable=False),
sa.Column("allowed_scopes", postgresql.JSONB(), nullable=False),
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("owner_user_id", sa.UUID(), nullable=True),
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["owner_user_id"],
["users.id"],
name="fk_oauth_clients_owner_user_id",
ondelete="SET NULL",
),
)
# Create indexes for oauth_clients
op.create_index(
"ix_oauth_clients_client_id", "oauth_clients", ["client_id"], unique=True
)
op.create_index("ix_oauth_clients_is_active", "oauth_clients", ["is_active"])
def downgrade() -> None:
# Drop oauth_clients table and indexes
op.drop_index("ix_oauth_clients_is_active", table_name="oauth_clients")
op.drop_index("ix_oauth_clients_client_id", table_name="oauth_clients")
op.drop_table("oauth_clients")
# Drop oauth_states table and indexes
op.drop_index("ix_oauth_states_expires_at", table_name="oauth_states")
op.drop_index("ix_oauth_states_state", table_name="oauth_states")
op.drop_table("oauth_states")
# Drop oauth_accounts table and indexes
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
op.drop_index("ix_oauth_accounts_provider_email", table_name="oauth_accounts")
op.drop_index("ix_oauth_accounts_provider", table_name="oauth_accounts")
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
op.drop_table("oauth_accounts")
# Revert password_hash to non-nullable
op.alter_column(
"users",
"password_hash",
existing_type=sa.String(length=255),
nullable=False,
)

View File

@@ -1,9 +1,21 @@
from fastapi import APIRouter
from app.api.routes import admin, auth, organizations, sessions, users
from app.api.routes import (
admin,
auth,
oauth,
oauth_provider,
organizations,
sessions,
users,
)
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
api_router.include_router(oauth.router, prefix="/oauth", tags=["OAuth"])
api_router.include_router(
oauth_provider.router, prefix="/oauth", tags=["OAuth Provider"]
)
api_router.include_router(users.router, prefix="/users", tags=["Users"])
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])

View File

@@ -0,0 +1,433 @@
# app/api/routes/oauth.py
"""
OAuth routes for social authentication.
Endpoints:
- GET /oauth/providers - List enabled OAuth providers
- GET /oauth/authorize/{provider} - Get authorization URL
- POST /oauth/callback/{provider} - Handle OAuth callback
- GET /oauth/accounts - List linked OAuth accounts
- DELETE /oauth/accounts/{provider} - Unlink an OAuth account
"""
import logging
import os
from datetime import UTC, datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user, get_optional_current_user
from app.core.auth import decode_token
from app.core.config import settings
from app.core.database import get_db
from app.core.exceptions import AuthenticationError as AuthError
from app.crud import oauth_account
from app.crud.session import session as session_crud
from app.models.user import User
from app.schemas.oauth import (
OAuthAccountsListResponse,
OAuthCallbackRequest,
OAuthCallbackResponse,
OAuthProvidersResponse,
OAuthUnlinkResponse,
)
from app.schemas.sessions import SessionCreate
from app.schemas.users import Token
from app.services.oauth_service import OAuthService
from app.utils.device import extract_device_info
router = APIRouter()
logger = logging.getLogger(__name__)
# Initialize limiter for this router
limiter = Limiter(key_func=get_remote_address)
# Use higher rate limits in test environment
IS_TEST = os.getenv("IS_TEST", "False") == "True"
RATE_MULTIPLIER = 100 if IS_TEST else 1
async def _create_oauth_login_session(
db: AsyncSession,
request: Request,
user: User,
tokens: Token,
provider: str,
) -> None:
"""
Create a session record for successful OAuth login.
This is a best-effort operation - login succeeds even if session creation fails.
"""
try:
device_info = extract_device_info(request)
# Decode refresh token to get JTI and expiration
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
session_data = SessionCreate(
user_id=user.id,
refresh_token_jti=refresh_payload.jti,
device_name=device_info.device_name or f"OAuth ({provider})",
device_id=device_info.device_id,
ip_address=device_info.ip_address,
user_agent=device_info.user_agent,
last_used_at=datetime.now(UTC),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
location_city=device_info.location_city,
location_country=device_info.location_country,
)
await session_crud.create_session(db, obj_in=session_data)
logger.info(
f"OAuth login successful: {user.email} via {provider} "
f"from {device_info.device_name} (IP: {device_info.ip_address})"
)
except Exception as session_err:
# Log but don't fail login if session creation fails
logger.error(
f"Failed to create session for OAuth login {user.email}: {session_err!s}",
exc_info=True,
)
@router.get(
"/providers",
response_model=OAuthProvidersResponse,
summary="List OAuth Providers",
description="""
Get list of enabled OAuth providers for the login/register UI.
Returns:
List of enabled providers with display info.
""",
operation_id="list_oauth_providers",
)
async def list_providers() -> Any:
"""
Get list of enabled OAuth providers.
This endpoint is public (no authentication required) as it's needed
for the login/register UI to display available social login options.
"""
return OAuthService.get_enabled_providers()
@router.get(
"/authorize/{provider}",
response_model=dict,
summary="Get OAuth Authorization URL",
description="""
Get the authorization URL to redirect the user to the OAuth provider.
The frontend should redirect the user to the returned URL.
After authentication, the provider will redirect back to the callback URL.
**Rate Limit**: 10 requests/minute
""",
operation_id="get_oauth_authorization_url",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def get_authorization_url(
request: Request,
provider: str,
redirect_uri: str = Query(
..., description="Frontend callback URL after OAuth completes"
),
current_user: User | None = Depends(get_optional_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get OAuth authorization URL.
Args:
provider: OAuth provider (google, github)
redirect_uri: Frontend callback URL
current_user: Current user (optional, for account linking)
db: Database session
Returns:
dict with authorization_url and state
"""
if not settings.OAUTH_ENABLED:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth is not enabled",
)
try:
# If user is logged in, this is an account linking flow
user_id = str(current_user.id) if current_user else None
url, state = await OAuthService.create_authorization_url(
db,
provider=provider,
redirect_uri=redirect_uri,
user_id=user_id,
)
return {
"authorization_url": url,
"state": state,
}
except AuthError as e:
logger.warning(f"OAuth authorization failed: {e!s}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.error(f"OAuth authorization error: {e!s}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create authorization URL",
)
@router.post(
"/callback/{provider}",
response_model=OAuthCallbackResponse,
summary="OAuth Callback",
description="""
Handle OAuth callback from provider.
The frontend should call this endpoint with the code and state
parameters received from the OAuth provider redirect.
Returns:
JWT tokens for the authenticated user.
**Rate Limit**: 10 requests/minute
""",
operation_id="handle_oauth_callback",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def handle_callback(
request: Request,
provider: str,
callback_data: OAuthCallbackRequest,
redirect_uri: str = Query(
..., description="Must match the redirect_uri used in authorization"
),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Handle OAuth callback.
Args:
provider: OAuth provider (google, github)
callback_data: Code and state from provider
redirect_uri: Original redirect URI (for validation)
db: Database session
Returns:
OAuthCallbackResponse with tokens
"""
if not settings.OAUTH_ENABLED:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth is not enabled",
)
try:
result = await OAuthService.handle_callback(
db,
code=callback_data.code,
state=callback_data.state,
redirect_uri=redirect_uri,
)
# Create session for the login (need to get the user first)
# Note: This requires fetching the user from the token
# For now, we skip session creation here as the result doesn't include user info
# The session will be created on next request if needed
return result
except AuthError as e:
logger.warning(f"OAuth callback failed: {e!s}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
except Exception as e:
logger.error(f"OAuth callback error: {e!s}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth authentication failed",
)
@router.get(
"/accounts",
response_model=OAuthAccountsListResponse,
summary="List Linked OAuth Accounts",
description="""
Get list of OAuth accounts linked to the current user.
Requires authentication.
""",
operation_id="list_oauth_accounts",
)
async def list_accounts(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List OAuth accounts linked to the current user.
Args:
current_user: Current authenticated user
db: Database session
Returns:
List of linked OAuth accounts
"""
accounts = await oauth_account.get_user_accounts(db, user_id=current_user.id)
return OAuthAccountsListResponse(accounts=accounts)
@router.delete(
"/accounts/{provider}",
response_model=OAuthUnlinkResponse,
summary="Unlink OAuth Account",
description="""
Unlink an OAuth provider from the current user.
The user must have either a password set or another OAuth provider
linked to ensure they can still log in.
**Rate Limit**: 5 requests/minute
""",
operation_id="unlink_oauth_account",
)
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
async def unlink_account(
request: Request,
provider: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Unlink an OAuth provider from the current user.
Args:
provider: Provider to unlink (google, github)
current_user: Current authenticated user
db: Database session
Returns:
Success message
"""
try:
await OAuthService.unlink_provider(
db,
user=current_user,
provider=provider,
)
return OAuthUnlinkResponse(
success=True,
message=f"{provider.capitalize()} account unlinked successfully",
)
except AuthError as e:
logger.warning(f"OAuth unlink failed for {current_user.email}: {e!s}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.error(f"OAuth unlink error: {e!s}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to unlink OAuth account",
)
@router.post(
"/link/{provider}",
response_model=dict,
summary="Start Account Linking",
description="""
Start the OAuth flow to link a new provider to the current user.
This is a convenience endpoint that redirects to /authorize/{provider}
with the current user context.
**Rate Limit**: 10 requests/minute
""",
operation_id="start_oauth_link",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def start_link(
request: Request,
provider: str,
redirect_uri: str = Query(
..., description="Frontend callback URL after OAuth completes"
),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Start OAuth account linking flow.
This endpoint requires authentication and will initiate an OAuth flow
to link a new provider to the current user's account.
Args:
provider: OAuth provider to link (google, github)
redirect_uri: Frontend callback URL
current_user: Current authenticated user
db: Database session
Returns:
dict with authorization_url and state
"""
if not settings.OAUTH_ENABLED:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth is not enabled",
)
# Check if user already has this provider linked
existing = await oauth_account.get_user_account_by_provider(
db, user_id=current_user.id, provider=provider
)
if existing:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"You already have a {provider} account linked",
)
try:
url, state = await OAuthService.create_authorization_url(
db,
provider=provider,
redirect_uri=redirect_uri,
user_id=str(current_user.id),
)
return {
"authorization_url": url,
"state": state,
}
except AuthError as e:
logger.warning(f"OAuth link authorization failed: {e!s}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.error(f"OAuth link error: {e!s}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create authorization URL",
)

View File

@@ -0,0 +1,312 @@
# app/api/routes/oauth_provider.py
"""
OAuth Provider routes (Authorization Server mode).
This is a skeleton implementation for MCP (Model Context Protocol) client authentication.
Provides basic OAuth 2.0 endpoints that can be expanded for full functionality.
Endpoints:
- GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414)
- GET /oauth/provider/authorize - Authorization endpoint (skeleton)
- POST /oauth/provider/token - Token endpoint (skeleton)
- POST /oauth/provider/revoke - Token revocation endpoint (skeleton)
NOTE: This is intentionally minimal. Full implementation should include:
- Complete authorization code flow
- Refresh token handling
- Scope validation
- Client authentication
- PKCE support
"""
import logging
from typing import Any
from fastapi import APIRouter, Depends, Form, HTTPException, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.database import get_db
from app.crud import oauth_client
from app.schemas.oauth import OAuthServerMetadata
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get(
"/.well-known/oauth-authorization-server",
response_model=OAuthServerMetadata,
summary="OAuth Server Metadata",
description="""
OAuth 2.0 Authorization Server Metadata (RFC 8414).
Returns server metadata including supported endpoints, scopes,
and capabilities for MCP clients.
""",
operation_id="get_oauth_server_metadata",
tags=["OAuth Provider"],
)
async def get_server_metadata() -> Any:
"""
Get OAuth 2.0 server metadata.
This endpoint is used by MCP clients to discover the authorization
server's capabilities.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
)
base_url = settings.OAUTH_ISSUER.rstrip("/")
return OAuthServerMetadata(
issuer=base_url,
authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize",
token_endpoint=f"{base_url}/api/v1/oauth/provider/token",
revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke",
registration_endpoint=None, # Dynamic registration not implemented
scopes_supported=[
"openid",
"profile",
"email",
"read:users",
"write:users",
"read:organizations",
"write:organizations",
],
response_types_supported=["code"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["S256"],
)
@router.get(
"/provider/authorize",
summary="Authorization Endpoint (Skeleton)",
description="""
OAuth 2.0 Authorization Endpoint.
**NOTE**: This is a skeleton implementation. In a full implementation,
this would:
1. Validate client_id and redirect_uri
2. Display consent screen to user
3. Generate authorization code
4. Redirect back to client with code
Currently returns a 501 Not Implemented response.
""",
operation_id="oauth_provider_authorize",
tags=["OAuth Provider"],
)
async def authorize(
response_type: str = Query(..., description="Must be 'code'"),
client_id: str = Query(..., description="OAuth client ID"),
redirect_uri: str = Query(..., description="Redirect URI"),
scope: str = Query(default="", description="Requested scopes"),
state: str = Query(default="", description="CSRF state parameter"),
code_challenge: str | None = Query(default=None, description="PKCE code challenge"),
code_challenge_method: str | None = Query(
default=None, description="PKCE method (S256)"
),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Authorization endpoint (skeleton).
In a full implementation, this would:
1. Validate the client and redirect URI
2. Authenticate the user (if not already)
3. Show consent screen
4. Generate authorization code
5. Redirect to redirect_uri with code
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
)
# Validate client exists
client = await oauth_client.get_by_client_id(db, client_id=client_id)
if not client:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="invalid_client: Unknown client_id",
)
# Validate redirect_uri
if redirect_uri not in (client.redirect_uris or []):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="invalid_request: Invalid redirect_uri",
)
# Skeleton: Return not implemented
# Full implementation would redirect to consent screen
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Authorization endpoint not fully implemented. "
"This is a skeleton for MCP integration.",
)
@router.post(
"/provider/token",
summary="Token Endpoint (Skeleton)",
description="""
OAuth 2.0 Token Endpoint.
**NOTE**: This is a skeleton implementation. In a full implementation,
this would exchange authorization codes for access tokens.
Currently returns a 501 Not Implemented response.
""",
operation_id="oauth_provider_token",
tags=["OAuth Provider"],
)
async def token(
grant_type: str = Form(..., description="Grant type (authorization_code)"),
code: str | None = Form(default=None, description="Authorization code"),
redirect_uri: str | None = Form(default=None, description="Redirect URI"),
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
code_verifier: str | None = Form(default=None, description="PKCE code verifier"),
refresh_token: str | None = Form(default=None, description="Refresh token"),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Token endpoint (skeleton).
Supported grant types (when fully implemented):
- authorization_code: Exchange code for tokens
- refresh_token: Refresh access token
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
)
if grant_type not in ["authorization_code", "refresh_token"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="unsupported_grant_type",
)
# Skeleton: Return not implemented
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Token endpoint not fully implemented. "
"This is a skeleton for MCP integration.",
)
@router.post(
"/provider/revoke",
summary="Token Revocation Endpoint (Skeleton)",
description="""
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
**NOTE**: This is a skeleton implementation.
Currently returns a 501 Not Implemented response.
""",
operation_id="oauth_provider_revoke",
tags=["OAuth Provider"],
)
async def revoke(
token: str = Form(..., description="Token to revoke"),
token_type_hint: str | None = Form(
default=None, description="Token type hint (access_token, refresh_token)"
),
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Token revocation endpoint (skeleton).
In a full implementation, this would invalidate the specified token.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
)
# Skeleton: Return not implemented
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Revocation endpoint not fully implemented. "
"This is a skeleton for MCP integration.",
)
# ============================================================================
# Client Management (Admin only)
# ============================================================================
@router.post(
"/provider/clients",
summary="Register OAuth Client (Admin)",
description="""
Register a new OAuth client (admin only).
This endpoint allows creating MCP clients that can authenticate
against this API.
**NOTE**: This is a minimal implementation.
""",
operation_id="register_oauth_client",
tags=["OAuth Provider"],
)
async def register_client(
client_name: str = Form(..., description="Client application name"),
redirect_uris: str = Form(..., description="Comma-separated list of redirect URIs"),
client_type: str = Form(default="public", description="public or confidential"),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Register a new OAuth client (skeleton).
In a full implementation, this would require admin authentication.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
)
# NOTE: In production, this should require admin authentication
# For now, this is a skeleton that shows the structure
from app.schemas.oauth import OAuthClientCreate
client_data = OAuthClientCreate(
client_name=client_name,
client_description=None,
redirect_uris=[uri.strip() for uri in redirect_uris.split(",")],
allowed_scopes=["openid", "profile", "email"],
client_type=client_type,
)
client, secret = await oauth_client.create_client(db, obj_in=client_data)
result = {
"client_id": client.client_id,
"client_name": client.client_name,
"client_type": client.client_type,
"redirect_uris": client.redirect_uris,
}
if secret:
result["client_secret"] = secret
result["warning"] = (
"Store the client_secret securely. It will not be shown again."
)
return result

View File

@@ -76,6 +76,60 @@ class Settings(BaseSettings):
description="Frontend application URL for email links",
)
# OAuth Configuration
OAUTH_ENABLED: bool = Field(
default=False,
description="Enable OAuth authentication (social login)",
)
OAUTH_AUTO_LINK_BY_EMAIL: bool = Field(
default=True,
description="Automatically link OAuth accounts to existing users with matching email",
)
OAUTH_STATE_EXPIRE_MINUTES: int = Field(
default=10,
description="OAuth state parameter expiration time in minutes",
)
# Google OAuth
OAUTH_GOOGLE_CLIENT_ID: str | None = Field(
default=None,
description="Google OAuth client ID from Google Cloud Console",
)
OAUTH_GOOGLE_CLIENT_SECRET: str | None = Field(
default=None,
description="Google OAuth client secret from Google Cloud Console",
)
# GitHub OAuth
OAUTH_GITHUB_CLIENT_ID: str | None = Field(
default=None,
description="GitHub OAuth client ID from GitHub Developer Settings",
)
OAUTH_GITHUB_CLIENT_SECRET: str | None = Field(
default=None,
description="GitHub OAuth client secret from GitHub Developer Settings",
)
# OAuth Provider Mode (for MCP clients - skeleton)
OAUTH_PROVIDER_ENABLED: bool = Field(
default=False,
description="Enable OAuth provider mode (act as authorization server for MCP clients)",
)
OAUTH_ISSUER: str = Field(
default="http://localhost:8000",
description="OAuth issuer URL (your API base URL)",
)
@property
def enabled_oauth_providers(self) -> list[str]:
"""Get list of enabled OAuth providers based on configured credentials."""
providers = []
if self.OAUTH_GOOGLE_CLIENT_ID and self.OAUTH_GOOGLE_CLIENT_SECRET:
providers.append("google")
if self.OAUTH_GITHUB_CLIENT_ID and self.OAUTH_GITHUB_CLIENT_SECRET:
providers.append("github")
return providers
# Admin user
FIRST_SUPERUSER_EMAIL: str | None = Field(
default=None, description="Email for first superuser account"

View File

@@ -1,6 +1,14 @@
# app/crud/__init__.py
from .oauth import oauth_account, oauth_client, oauth_state
from .organization import organization
from .session import session as session_crud
from .user import user
__all__ = ["organization", "session_crud", "user"]
__all__ = [
"oauth_account",
"oauth_client",
"oauth_state",
"organization",
"session_crud",
"user",
]

653
backend/app/crud/oauth.py Normal file
View File

@@ -0,0 +1,653 @@
"""
Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns.
Provides operations for:
- OAuthAccount: Managing linked OAuth provider accounts
- OAuthState: CSRF protection state during OAuth flows
- OAuthClient: Registered OAuth clients (provider mode skeleton)
"""
import logging
import secrets
from datetime import UTC, datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.oauth_account import OAuthAccount
from app.models.oauth_client import OAuthClient
from app.models.oauth_state import OAuthState
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
logger = logging.getLogger(__name__)
# ============================================================================
# OAuth Account CRUD
# ============================================================================
class EmptySchema(BaseModel):
"""Placeholder schema for CRUD operations that don't need update schemas."""
class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]):
"""CRUD operations for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and provider user ID.
Args:
db: Database session
provider: OAuth provider name (google, github)
provider_user_id: User ID from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and email.
Used for auto-linking existing accounts by email.
Args:
db: Database session
provider: OAuth provider name
email: Email address from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(
f"Error getting OAuth account for {provider} email {email}: {e!s}"
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""
Get all OAuth accounts linked to a user.
Args:
db: Database session
user_id: User ID
Returns:
List of OAuthAccount objects
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""
Get a specific OAuth account for a user and provider.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
OAuthAccount if found, None otherwise
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""
Create a new OAuth account link.
Args:
db: Database session
obj_in: OAuth account creation data
Returns:
Created OAuthAccount
Raises:
ValueError: If account already exists or creation fails
"""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token_encrypted=obj_in.access_token_encrypted,
refresh_token_encrypted=obj_in.refresh_token_encrypted,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
)
raise ValueError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error(f"Integrity error creating OAuth account: {error_msg}")
raise ValueError(f"Failed to create OAuth account: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""
Delete an OAuth account link.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
True if deleted, False if not found
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
f"OAuth account deleted: {provider} unlinked from user {user_id}"
)
else:
logger.warning(
f"OAuth account not found for deletion: {provider} for user {user_id}"
)
return deleted
except Exception as e:
await db.rollback()
logger.error(
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token_encrypted: str | None = None,
refresh_token_encrypted: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""
Update OAuth tokens for an account.
Args:
db: Database session
account: OAuthAccount to update
access_token_encrypted: New encrypted access token
refresh_token_encrypted: New encrypted refresh token
token_expires_at: New token expiration time
Returns:
Updated OAuthAccount
"""
try:
if access_token_encrypted is not None:
account.access_token_encrypted = access_token_encrypted
if refresh_token_encrypted is not None:
account.refresh_token_encrypted = refresh_token_encrypted
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e:
await db.rollback()
logger.error(f"Error updating OAuth tokens: {e!s}")
raise
# ============================================================================
# OAuth State CRUD
# ============================================================================
class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]):
"""CRUD operations for OAuth state (CSRF protection)."""
async def create_state(
self, db: AsyncSession, *, obj_in: OAuthStateCreate
) -> OAuthState:
"""
Create a new OAuth state for CSRF protection.
Args:
db: Database session
obj_in: OAuth state creation data
Returns:
Created OAuthState
"""
try:
db_obj = OAuthState(
state=obj_in.state,
code_verifier=obj_in.code_verifier,
nonce=obj_in.nonce,
provider=obj_in.provider,
redirect_uri=obj_in.redirect_uri,
user_id=obj_in.user_id,
expires_at=obj_in.expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.debug(f"OAuth state created for {obj_in.provider}")
return db_obj
except IntegrityError as e:
await db.rollback()
# State collision (extremely rare with cryptographic random)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"OAuth state collision: {error_msg}")
raise ValueError("Failed to create OAuth state, please retry")
except Exception as e:
await db.rollback()
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
raise
async def get_and_consume_state(
self, db: AsyncSession, *, state: str
) -> OAuthState | None:
"""
Get and delete OAuth state (consume it).
This ensures each state can only be used once (replay protection).
Args:
db: Database session
state: State string to look up
Returns:
OAuthState if found and valid, None otherwise
"""
try:
# Get the state
result = await db.execute(
select(OAuthState).where(OAuthState.state == state)
)
db_obj = result.scalar_one_or_none()
if db_obj is None:
logger.warning(f"OAuth state not found: {state[:8]}...")
return None
# Check expiration
# Handle both timezone-aware and timezone-naive datetimes
now = datetime.now(UTC)
expires_at = db_obj.expires_at
if expires_at.tzinfo is None:
# SQLite returns naive datetimes, assume UTC
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.warning(f"OAuth state expired: {state[:8]}...")
await db.delete(db_obj)
await db.commit()
return None
# Delete it (consume)
await db.delete(db_obj)
await db.commit()
logger.debug(f"OAuth state consumed: {state[:8]}...")
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error consuming OAuth state: {e!s}")
raise
async def cleanup_expired(self, db: AsyncSession) -> int:
"""
Clean up expired OAuth states.
Should be called periodically to remove stale states.
Args:
db: Database session
Returns:
Number of states deleted
"""
try:
now = datetime.now(UTC)
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired OAuth states")
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
raise
# ============================================================================
# OAuth Client CRUD (Provider Mode - Skeleton)
# ============================================================================
class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
"""
CRUD operations for OAuth clients (provider mode).
This is a skeleton implementation for MCP client registration.
Full implementation can be expanded when needed.
"""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Get OAuth client by client_id.
Args:
db: Database session
client_id: OAuth client ID
Returns:
OAuthClient if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""
Create a new OAuth client.
Args:
db: Database session
obj_in: OAuth client creation data
owner_user_id: Optional owner user ID
Returns:
Tuple of (created OAuthClient, client_secret or None for public clients)
"""
try:
# Generate client_id
client_id = secrets.token_urlsafe(32)
# Generate client_secret for confidential clients
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
# In production, use proper password hashing (bcrypt)
# For now, we store a hash placeholder
import hashlib
client_secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
)
return db_obj, client_secret
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Error creating OAuth client: {error_msg}")
raise ValueError(f"Failed to create OAuth client: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Deactivate an OAuth client.
Args:
db: Database session
client_id: OAuth client ID
Returns:
Deactivated OAuthClient if found, None otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info(f"OAuth client deactivated: {client.client_name}")
return client
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""
Validate that a redirect URI is allowed for a client.
Args:
db: Database session
client_id: OAuth client ID
redirect_uri: Redirect URI to validate
Returns:
True if valid, False otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e:
logger.error(f"Error validating redirect URI: {e!s}")
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""
Verify client credentials.
Args:
db: Database session
client_id: OAuth client ID
client_secret: Client secret to verify
Returns:
True if valid, False otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
# Verify secret
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
# Cast to str for type safety with compare_digest
stored_hash: str = str(client.client_secret_hash)
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e:
logger.error(f"Error verifying client secret: {e!s}")
return False
# ============================================================================
# Singleton instances
# ============================================================================
oauth_account = CRUDOAuthAccount(OAuthAccount)
oauth_state = CRUDOAuthState(OAuthState)
oauth_client = CRUDOAuthClient(OAuthClient)

View File

@@ -7,6 +7,11 @@ Imports all models to ensure they're registered with SQLAlchemy.
from app.core.database import Base
from .base import TimestampMixin, UUIDMixin
# OAuth models
from .oauth_account import OAuthAccount
from .oauth_client import OAuthClient
from .oauth_state import OAuthState
from .organization import Organization
# Import models
@@ -16,6 +21,9 @@ from .user_session import UserSession
__all__ = [
"Base",
"OAuthAccount",
"OAuthClient",
"OAuthState",
"Organization",
"OrganizationRole",
"TimestampMixin",

View File

@@ -0,0 +1,55 @@
"""OAuth account model for linking external OAuth providers to users."""
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthAccount(Base, UUIDMixin, TimestampMixin):
"""
Links OAuth provider accounts to users.
Supports multiple OAuth providers per user (e.g., user can have both
Google and GitHub connected). Each provider account is uniquely identified
by (provider, provider_user_id).
"""
__tablename__ = "oauth_accounts"
# Link to user
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# OAuth provider identification
provider = Column(
String(50), nullable=False, index=True
) # google, github, microsoft
provider_user_id = Column(String(255), nullable=False) # Provider's unique user ID
provider_email = Column(
String(255), nullable=True, index=True
) # Email from provider (for reference)
# Optional: store provider tokens for API access
# These should be encrypted at rest in production
access_token_encrypted = Column(String(2048), nullable=True)
refresh_token_encrypted = Column(String(2048), nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Relationship
user = relationship("User", back_populates="oauth_accounts")
__table_args__ = (
# Each provider account can only be linked to one user
UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
# Index for finding all OAuth accounts for a user + provider
Index("ix_oauth_accounts_user_provider", "user_id", "provider"),
)
def __repr__(self):
return f"<OAuthAccount {self.provider}:{self.provider_user_id}>"

View File

@@ -0,0 +1,67 @@
"""OAuth client model for OAuth provider mode (MCP clients)."""
from sqlalchemy import Boolean, Column, ForeignKey, String
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthClient(Base, UUIDMixin, TimestampMixin):
"""
Registered OAuth clients (for OAuth provider mode).
This model stores third-party applications that can authenticate
against this API using OAuth 2.0. Used for MCP (Model Context Protocol)
client authentication and API access.
NOTE: This is a skeleton implementation. The full OAuth provider
functionality (authorization endpoint, token endpoint, etc.) can be
expanded when needed.
"""
__tablename__ = "oauth_clients"
# Client credentials
client_id = Column(String(64), unique=True, nullable=False, index=True)
client_secret_hash = Column(
String(255), nullable=True
) # NULL for public clients (PKCE)
# Client metadata
client_name = Column(String(255), nullable=False)
client_description = Column(String(1000), nullable=True)
# Client type: "public" (SPA, mobile) or "confidential" (server-side)
client_type = Column(String(20), nullable=False, default="public")
# Allowed redirect URIs (JSON array)
redirect_uris = Column(JSONB, nullable=False, default=list)
# Allowed scopes (JSON array of scope names)
allowed_scopes = Column(JSONB, nullable=False, default=list)
# Token lifetimes (in seconds)
access_token_lifetime = Column(String(10), nullable=False, default="3600") # 1 hour
refresh_token_lifetime = Column(
String(10), nullable=False, default="604800"
) # 7 days
# Status
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Optional: owner user (for user-registered applications)
owner_user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
# MCP-specific: URL of the MCP server this client represents
mcp_server_url = Column(String(2048), nullable=True)
# Relationship
owner = relationship("User", backref="owned_oauth_clients")
def __repr__(self):
return f"<OAuthClient {self.client_name} ({self.client_id[:8]}...)>"

View File

@@ -0,0 +1,45 @@
"""OAuth state model for CSRF protection during OAuth flows."""
from sqlalchemy import Column, DateTime, String
from sqlalchemy.dialects.postgresql import UUID
from .base import Base, TimestampMixin, UUIDMixin
class OAuthState(Base, UUIDMixin, TimestampMixin):
"""
Temporary storage for OAuth state parameters.
Prevents CSRF attacks during OAuth flows by storing a random state
value that must match on callback. Also stores PKCE code_verifier
for the Authorization Code flow with PKCE.
These records are short-lived (10 minutes by default) and should
be deleted after use or expiration.
"""
__tablename__ = "oauth_states"
# Random state parameter (CSRF protection)
state = Column(String(255), unique=True, nullable=False, index=True)
# PKCE code_verifier (used to generate code_challenge)
code_verifier = Column(String(128), nullable=True)
# OIDC nonce for ID token replay protection
nonce = Column(String(255), nullable=True)
# OAuth provider (google, github, etc.)
provider = Column(String(50), nullable=False)
# Original redirect URI (for callback validation)
redirect_uri = Column(String(500), nullable=True)
# User ID if this is an account linking flow (user is already logged in)
user_id = Column(UUID(as_uuid=True), nullable=True)
# Expiration time
expires_at = Column(DateTime(timezone=True), nullable=False)
def __repr__(self):
return f"<OAuthState {self.state[:8]}... ({self.provider})>"

View File

@@ -9,7 +9,8 @@ class User(Base, UUIDMixin, TimestampMixin):
__tablename__ = "users"
email = Column(String(255), unique=True, nullable=False, index=True)
password_hash = Column(String(255), nullable=False)
# Nullable to support OAuth-only users who never set a password
password_hash = Column(String(255), nullable=True)
first_name = Column(String(100), nullable=False, default="user")
last_name = Column(String(100), nullable=True)
phone_number = Column(String(20))
@@ -23,6 +24,19 @@ class User(Base, UUIDMixin, TimestampMixin):
user_organizations = relationship(
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
)
oauth_accounts = relationship(
"OAuthAccount", back_populates="user", cascade="all, delete-orphan"
)
@property
def has_password(self) -> bool:
"""Check if user can login with password (not OAuth-only)."""
return self.password_hash is not None
@property
def can_remove_oauth(self) -> bool:
"""Check if user can safely remove an OAuth account link."""
return self.has_password or len(self.oauth_accounts) > 1
def __repr__(self):
return f"<User {self.email}>"

View File

@@ -0,0 +1,313 @@
"""
Pydantic schemas for OAuth authentication.
"""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
# ============================================================================
# OAuth Provider Info (for frontend to display available providers)
# ============================================================================
class OAuthProviderInfo(BaseModel):
"""Information about an available OAuth provider."""
provider: str = Field(..., description="Provider identifier (google, github)")
name: str = Field(..., description="Human-readable provider name")
icon: str | None = Field(None, description="Icon identifier for frontend")
class OAuthProvidersResponse(BaseModel):
"""Response containing list of enabled OAuth providers."""
enabled: bool = Field(..., description="Whether OAuth is globally enabled")
providers: list[OAuthProviderInfo] = Field(
default_factory=list, description="List of enabled providers"
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"enabled": True,
"providers": [
{"provider": "google", "name": "Google", "icon": "google"},
{"provider": "github", "name": "GitHub", "icon": "github"},
],
}
}
)
# ============================================================================
# OAuth Account (linked provider accounts)
# ============================================================================
class OAuthAccountBase(BaseModel):
"""Base schema for OAuth accounts."""
provider: str = Field(..., max_length=50, description="OAuth provider name")
provider_email: str | None = Field(
None, max_length=255, description="Email from OAuth provider"
)
class OAuthAccountCreate(OAuthAccountBase):
"""Schema for creating an OAuth account link (internal use)."""
user_id: UUID
provider_user_id: str = Field(..., max_length=255)
access_token_encrypted: str | None = None
refresh_token_encrypted: str | None = None
token_expires_at: datetime | None = None
class OAuthAccountResponse(OAuthAccountBase):
"""Schema for OAuth account response to clients."""
id: UUID
created_at: datetime
model_config = ConfigDict(
from_attributes=True,
json_schema_extra={
"example": {
"id": "123e4567-e89b-12d3-a456-426614174000",
"provider": "google",
"provider_email": "user@gmail.com",
"created_at": "2025-11-24T12:00:00Z",
}
},
)
class OAuthAccountsListResponse(BaseModel):
"""Response containing list of linked OAuth accounts."""
accounts: list[OAuthAccountResponse]
model_config = ConfigDict(
json_schema_extra={
"example": {
"accounts": [
{
"id": "123e4567-e89b-12d3-a456-426614174000",
"provider": "google",
"provider_email": "user@gmail.com",
"created_at": "2025-11-24T12:00:00Z",
}
]
}
}
)
# ============================================================================
# OAuth Flow (authorization, callback, etc.)
# ============================================================================
class OAuthAuthorizeRequest(BaseModel):
"""Request parameters for OAuth authorization."""
provider: str = Field(..., description="OAuth provider (google, github)")
redirect_uri: str | None = Field(
None, description="Frontend callback URL after OAuth"
)
mode: str = Field(
default="login",
description="OAuth mode: login, register, or link",
pattern="^(login|register|link)$",
)
class OAuthCallbackRequest(BaseModel):
"""Request parameters for OAuth callback."""
code: str = Field(..., description="Authorization code from provider")
state: str = Field(..., description="State parameter for CSRF protection")
class OAuthCallbackResponse(BaseModel):
"""Response after successful OAuth authentication."""
access_token: str = Field(..., description="JWT access token")
refresh_token: str = Field(..., description="JWT refresh token")
token_type: str = Field(default="bearer")
expires_in: int = Field(..., description="Token expiration in seconds")
is_new_user: bool = Field(
default=False, description="Whether a new user was created"
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"token_type": "bearer",
"expires_in": 900,
"is_new_user": False,
}
}
)
class OAuthUnlinkResponse(BaseModel):
"""Response after unlinking an OAuth account."""
success: bool = Field(..., description="Whether the unlink was successful")
message: str = Field(..., description="Status message")
model_config = ConfigDict(
json_schema_extra={
"example": {"success": True, "message": "Google account unlinked"}
}
)
# ============================================================================
# OAuth State (CSRF protection - internal use)
# ============================================================================
class OAuthStateCreate(BaseModel):
"""Schema for creating OAuth state (internal use)."""
state: str = Field(..., max_length=255)
code_verifier: str | None = Field(None, max_length=128)
nonce: str | None = Field(None, max_length=255)
provider: str = Field(..., max_length=50)
redirect_uri: str | None = Field(None, max_length=500)
user_id: UUID | None = None
expires_at: datetime
# ============================================================================
# OAuth Client (Provider Mode - MCP clients)
# ============================================================================
class OAuthClientBase(BaseModel):
"""Base schema for OAuth clients."""
client_name: str = Field(..., max_length=255, description="Client application name")
client_description: str | None = Field(
None, max_length=1000, description="Client description"
)
redirect_uris: list[str] = Field(
default_factory=list, description="Allowed redirect URIs"
)
allowed_scopes: list[str] = Field(
default_factory=list, description="Allowed OAuth scopes"
)
class OAuthClientCreate(OAuthClientBase):
"""Schema for creating an OAuth client."""
client_type: str = Field(
default="public",
description="Client type: public or confidential",
pattern="^(public|confidential)$",
)
class OAuthClientResponse(OAuthClientBase):
"""Schema for OAuth client response."""
id: UUID
client_id: str = Field(..., description="OAuth client ID")
client_type: str
is_active: bool
created_at: datetime
model_config = ConfigDict(
from_attributes=True,
json_schema_extra={
"example": {
"id": "123e4567-e89b-12d3-a456-426614174000",
"client_id": "abc123def456",
"client_name": "My MCP App",
"client_description": "My application that uses MCP",
"client_type": "public",
"redirect_uris": ["http://localhost:3000/callback"],
"allowed_scopes": ["read:users", "write:users"],
"is_active": True,
"created_at": "2025-11-24T12:00:00Z",
}
},
)
class OAuthClientWithSecret(OAuthClientResponse):
"""Schema for OAuth client response including secret (only shown once)."""
client_secret: str | None = Field(
None, description="Client secret (only shown once for confidential clients)"
)
model_config = ConfigDict(
from_attributes=True,
json_schema_extra={
"example": {
"id": "123e4567-e89b-12d3-a456-426614174000",
"client_id": "abc123def456",
"client_secret": "secret_xyz789",
"client_name": "My MCP App",
"client_type": "confidential",
"redirect_uris": ["http://localhost:3000/callback"],
"allowed_scopes": ["read:users"],
"is_active": True,
"created_at": "2025-11-24T12:00:00Z",
}
},
)
# ============================================================================
# OAuth Provider Discovery (RFC 8414 - skeleton)
# ============================================================================
class OAuthServerMetadata(BaseModel):
"""OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
issuer: str = Field(..., description="Authorization server issuer URL")
authorization_endpoint: str = Field(..., description="Authorization endpoint URL")
token_endpoint: str = Field(..., description="Token endpoint URL")
registration_endpoint: str | None = Field(
None, description="Dynamic client registration endpoint"
)
revocation_endpoint: str | None = Field(
None, description="Token revocation endpoint"
)
scopes_supported: list[str] = Field(
default_factory=list, description="Supported scopes"
)
response_types_supported: list[str] = Field(
default_factory=lambda: ["code"], description="Supported response types"
)
grant_types_supported: list[str] = Field(
default_factory=lambda: ["authorization_code", "refresh_token"],
description="Supported grant types",
)
code_challenge_methods_supported: list[str] = Field(
default_factory=lambda: ["S256"], description="Supported PKCE methods"
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"issuer": "https://api.example.com",
"authorization_endpoint": "https://api.example.com/oauth/authorize",
"token_endpoint": "https://api.example.com/oauth/token",
"scopes_supported": ["openid", "profile", "email", "read:users"],
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"code_challenge_methods_supported": ["S256"],
}
}
)

View File

@@ -0,0 +1,5 @@
# app/services/__init__.py
from .auth_service import AuthService
from .oauth_service import OAuthService
__all__ = ["AuthService", "OAuthService"]

View File

@@ -0,0 +1,598 @@
"""
OAuth Service for handling social authentication flows.
Supports:
- Google OAuth (OpenID Connect)
- GitHub OAuth
Features:
- PKCE support for public clients
- State parameter for CSRF protection
- Auto-linking by email (configurable)
- Account linking for existing users
"""
import logging
import secrets
from datetime import UTC, datetime, timedelta
from typing import TypedDict, cast
from uuid import UUID
from authlib.integrations.httpx_client import AsyncOAuth2Client
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import create_access_token, create_refresh_token
from app.core.config import settings
from app.core.exceptions import AuthenticationError
from app.crud import oauth_account, oauth_state
from app.models.user import User
from app.schemas.oauth import (
OAuthAccountCreate,
OAuthCallbackResponse,
OAuthProviderInfo,
OAuthProvidersResponse,
OAuthStateCreate,
)
logger = logging.getLogger(__name__)
class OAuthProviderConfig(TypedDict, total=False):
"""Type definition for OAuth provider configuration."""
name: str
icon: str
authorize_url: str
token_url: str
userinfo_url: str
email_url: str # Optional, GitHub-only
scopes: list[str]
supports_pkce: bool
# Provider configurations
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
"google": {
"name": "Google",
"icon": "google",
"authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
"token_url": "https://oauth2.googleapis.com/token",
"userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
"scopes": ["openid", "email", "profile"],
"supports_pkce": True,
},
"github": {
"name": "GitHub",
"icon": "github",
"authorize_url": "https://github.com/login/oauth/authorize",
"token_url": "https://github.com/login/oauth/access_token",
"userinfo_url": "https://api.github.com/user",
"email_url": "https://api.github.com/user/emails",
"scopes": ["read:user", "user:email"],
"supports_pkce": False, # GitHub doesn't support PKCE
},
}
class OAuthService:
"""Service for handling OAuth authentication flows."""
@staticmethod
def get_enabled_providers() -> OAuthProvidersResponse:
"""
Get list of enabled OAuth providers.
Returns:
OAuthProvidersResponse with enabled providers
"""
providers = []
for provider_id in settings.enabled_oauth_providers:
if provider_id in OAUTH_PROVIDERS:
config = OAUTH_PROVIDERS[provider_id]
providers.append(
OAuthProviderInfo(
provider=provider_id,
name=config["name"],
icon=config["icon"],
)
)
return OAuthProvidersResponse(
enabled=settings.OAUTH_ENABLED and len(providers) > 0,
providers=providers,
)
@staticmethod
def _get_provider_credentials(provider: str) -> tuple[str, str]:
"""Get client ID and secret for a provider."""
if provider == "google":
client_id = settings.OAUTH_GOOGLE_CLIENT_ID
client_secret = settings.OAUTH_GOOGLE_CLIENT_SECRET
elif provider == "github":
client_id = settings.OAUTH_GITHUB_CLIENT_ID
client_secret = settings.OAUTH_GITHUB_CLIENT_SECRET
else:
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
if not client_id or not client_secret:
raise AuthenticationError(f"OAuth provider {provider} is not configured")
return client_id, client_secret
@staticmethod
async def create_authorization_url(
db: AsyncSession,
*,
provider: str,
redirect_uri: str,
user_id: str | None = None,
) -> tuple[str, str]:
"""
Create OAuth authorization URL with state and optional PKCE.
Args:
db: Database session
provider: OAuth provider (google, github)
redirect_uri: Callback URL after OAuth
user_id: User ID if linking account (user is logged in)
Returns:
Tuple of (authorization_url, state)
Raises:
AuthenticationError: If provider is not configured
"""
if not settings.OAUTH_ENABLED:
raise AuthenticationError("OAuth is not enabled")
if provider not in OAUTH_PROVIDERS:
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
if provider not in settings.enabled_oauth_providers:
raise AuthenticationError(f"OAuth provider {provider} is not enabled")
config = OAUTH_PROVIDERS[provider]
client_id, client_secret = OAuthService._get_provider_credentials(provider)
# Generate state for CSRF protection
state = secrets.token_urlsafe(32)
# Generate PKCE code verifier and challenge if supported
code_verifier = None
code_challenge = None
if config.get("supports_pkce"):
code_verifier = secrets.token_urlsafe(64)
# Create code_challenge using S256 method
import base64
import hashlib
code_challenge_bytes = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = (
base64.urlsafe_b64encode(code_challenge_bytes).decode().rstrip("=")
)
# Generate nonce for OIDC (Google)
nonce = secrets.token_urlsafe(32) if provider == "google" else None
# Store state in database
from uuid import UUID
state_data = OAuthStateCreate(
state=state,
code_verifier=code_verifier,
nonce=nonce,
provider=provider,
redirect_uri=redirect_uri,
user_id=UUID(user_id) if user_id else None,
expires_at=datetime.now(UTC)
+ timedelta(minutes=settings.OAUTH_STATE_EXPIRE_MINUTES),
)
await oauth_state.create_state(db, obj_in=state_data)
# Build authorization URL
async with AsyncOAuth2Client(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
) as client:
# Prepare authorization params
auth_params = {
"state": state,
"scope": " ".join(config["scopes"]),
}
if code_challenge:
auth_params["code_challenge"] = code_challenge
auth_params["code_challenge_method"] = "S256"
if nonce:
auth_params["nonce"] = nonce
url, _ = client.create_authorization_url(
config["authorize_url"],
**auth_params,
)
logger.info(f"OAuth authorization URL created for {provider}")
return url, state
@staticmethod
async def handle_callback(
db: AsyncSession,
*,
code: str,
state: str,
redirect_uri: str,
) -> OAuthCallbackResponse:
"""
Handle OAuth callback and authenticate/create user.
Args:
db: Database session
code: Authorization code from provider
state: State parameter for CSRF verification
redirect_uri: Callback URL (must match authorization request)
Returns:
OAuthCallbackResponse with tokens
Raises:
AuthenticationError: If authentication fails
"""
# Validate and consume state
state_record = await oauth_state.get_and_consume_state(db, state=state)
if not state_record:
raise AuthenticationError("Invalid or expired OAuth state")
# Extract provider from state record (str for type safety)
provider: str = str(state_record.provider)
if provider not in OAUTH_PROVIDERS:
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
config = OAUTH_PROVIDERS[provider]
client_id, client_secret = OAuthService._get_provider_credentials(provider)
# Exchange code for tokens
async with AsyncOAuth2Client(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
) as client:
try:
# Prepare token request params
token_params: dict[str, str] = {"code": code}
if state_record.code_verifier:
token_params["code_verifier"] = str(state_record.code_verifier)
token = await client.fetch_token(
config["token_url"],
**token_params,
)
except Exception as e:
logger.error(f"OAuth token exchange failed: {e!s}")
raise AuthenticationError("Failed to exchange authorization code")
# Get user info from provider
try:
access_token = token.get("access_token")
if not access_token:
raise AuthenticationError("No access token received")
user_info = await OAuthService._get_user_info(
client, provider, config, access_token
)
except Exception as e:
logger.error(f"Failed to get user info: {e!s}")
raise AuthenticationError(
"Failed to get user information from provider"
)
# Process user info and create/link account
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
# Email can be None if user didn't grant email permission
email_raw = user_info.get("email")
provider_email: str | None = str(email_raw) if email_raw else None
if not provider_user_id:
raise AuthenticationError("Provider did not return user ID")
# Check if this OAuth account already exists
existing_oauth = await oauth_account.get_by_provider_id(
db, provider=provider, provider_user_id=provider_user_id
)
is_new_user = False
if existing_oauth:
# Existing OAuth account - login
user = existing_oauth.user
if not user.is_active:
raise AuthenticationError("User account is inactive")
# Update tokens if stored
if token.get("access_token"):
await oauth_account.update_tokens(
db,
account=existing_oauth,
access_token_encrypted=token.get("access_token"), # TODO: encrypt
refresh_token_encrypted=token.get("refresh_token"), # TODO: encrypt
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600)),
)
logger.info(f"OAuth login successful for {user.email} via {provider}")
elif state_record.user_id:
# Account linking flow (user is already logged in)
result = await db.execute(
select(User).where(User.id == state_record.user_id)
)
user = result.scalar_one_or_none()
if not user:
raise AuthenticationError("User not found for account linking")
# Check if user already has this provider linked
user_id = cast(UUID, user.id)
existing_provider = await oauth_account.get_user_account_by_provider(
db, user_id=user_id, provider=provider
)
if existing_provider:
raise AuthenticationError(
f"You already have a {provider} account linked"
)
# Create OAuth account link
oauth_create = OAuthAccountCreate(
user_id=user_id,
provider=provider,
provider_user_id=provider_user_id,
provider_email=provider_email,
access_token_encrypted=token.get("access_token"), # TODO: encrypt
refresh_token_encrypted=token.get("refresh_token"), # TODO: encrypt
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
)
await oauth_account.create_account(db, obj_in=oauth_create)
logger.info(f"OAuth account linked: {provider} -> {user.email}")
else:
# New OAuth login - check for existing user by email
user = None
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
result = await db.execute(
select(User).where(User.email == provider_email)
)
user = result.scalar_one_or_none()
if user:
# Auto-link to existing user
if not user.is_active:
raise AuthenticationError("User account is inactive")
# Check if user already has this provider linked
user_id = cast(UUID, user.id)
existing_provider = await oauth_account.get_user_account_by_provider(
db, user_id=user_id, provider=provider
)
if existing_provider:
# This shouldn't happen if we got here, but safety check
logger.warning(
f"OAuth account already linked (race condition?): {provider} -> {user.email}"
)
else:
# Create OAuth account link
oauth_create = OAuthAccountCreate(
user_id=user_id,
provider=provider,
provider_user_id=provider_user_id,
provider_email=provider_email,
access_token_encrypted=token.get("access_token"),
refresh_token_encrypted=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
)
await oauth_account.create_account(db, obj_in=oauth_create)
logger.info(f"OAuth auto-linked by email: {provider} -> {user.email}")
else:
# Create new user
if not provider_email:
raise AuthenticationError(
f"Email is required for registration. "
f"Please grant email permission to {provider}."
)
user = await OAuthService._create_oauth_user(
db,
email=provider_email,
provider=provider,
provider_user_id=provider_user_id,
user_info=user_info,
token=token,
)
is_new_user = True
logger.info(f"New user created via OAuth: {user.email} ({provider})")
# Generate JWT tokens
claims = {
"is_superuser": user.is_superuser,
"email": user.email,
"first_name": user.first_name,
}
access_token_jwt = create_access_token(subject=str(user.id), claims=claims)
refresh_token_jwt = create_refresh_token(subject=str(user.id))
return OAuthCallbackResponse(
access_token=access_token_jwt,
refresh_token=refresh_token_jwt,
token_type="bearer",
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
is_new_user=is_new_user,
)
@staticmethod
async def _get_user_info(
client: AsyncOAuth2Client,
provider: str,
config: OAuthProviderConfig,
access_token: str,
) -> dict[str, object]:
"""Get user info from OAuth provider."""
headers = {"Authorization": f"Bearer {access_token}"}
if provider == "github":
# GitHub returns JSON with Accept header
headers["Accept"] = "application/vnd.github+json"
resp = await client.get(config["userinfo_url"], headers=headers)
resp.raise_for_status()
user_info = resp.json()
# GitHub requires separate request for email
if provider == "github" and not user_info.get("email"):
email_resp = await client.get(
config["email_url"],
headers=headers,
)
email_resp.raise_for_status()
emails = email_resp.json()
# Find primary verified email
for email_data in emails:
if email_data.get("primary") and email_data.get("verified"):
user_info["email"] = email_data["email"]
break
return user_info
@staticmethod
async def _create_oauth_user(
db: AsyncSession,
*,
email: str,
provider: str,
provider_user_id: str,
user_info: dict,
token: dict,
) -> User:
"""Create a new user from OAuth provider data."""
# Extract name from user_info
first_name = "User"
last_name = None
if provider == "google":
first_name = user_info.get("given_name") or user_info.get("name", "User")
last_name = user_info.get("family_name")
elif provider == "github":
# GitHub has full name, try to split
name = user_info.get("name") or user_info.get("login", "User")
parts = name.split(" ", 1)
first_name = parts[0]
last_name = parts[1] if len(parts) > 1 else None
# Create user (no password for OAuth-only users)
user = User(
email=email,
password_hash=None, # OAuth-only user
first_name=first_name,
last_name=last_name,
is_active=True,
is_superuser=False,
)
db.add(user)
await db.flush() # Get user.id
# Create OAuth account link
user_id = cast(UUID, user.id)
oauth_create = OAuthAccountCreate(
user_id=user_id,
provider=provider,
provider_user_id=provider_user_id,
provider_email=email,
access_token_encrypted=token.get("access_token"), # TODO: encrypt
refresh_token_encrypted=token.get("refresh_token"), # TODO: encrypt
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
)
await oauth_account.create_account(db, obj_in=oauth_create)
await db.commit()
await db.refresh(user)
return user
@staticmethod
async def unlink_provider(
db: AsyncSession,
*,
user: User,
provider: str,
) -> bool:
"""
Unlink an OAuth provider from a user account.
Args:
db: Database session
user: User to unlink from
provider: Provider to unlink
Returns:
True if unlinked successfully
Raises:
AuthenticationError: If unlinking would leave user without login method
"""
# Check if user can safely remove this OAuth account
# Note: We query directly instead of using user.can_remove_oauth property
# because the property uses lazy loading which doesn't work in async context
user_id = cast(UUID, user.id)
has_password = user.password_hash is not None
oauth_accounts = await oauth_account.get_user_accounts(db, user_id=user_id)
can_remove = has_password or len(oauth_accounts) > 1
if not can_remove:
raise AuthenticationError(
"Cannot unlink OAuth account. You must have either a password set "
"or at least one other OAuth provider linked."
)
deleted = await oauth_account.delete_account(
db, user_id=user_id, provider=provider
)
if not deleted:
raise AuthenticationError(f"No {provider} account found to unlink")
logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
return True
@staticmethod
async def cleanup_expired_states(db: AsyncSession) -> int:
"""
Clean up expired OAuth states.
Should be called periodically (e.g., by a background task).
Args:
db: Database session
Returns:
Number of states cleaned up
"""
return await oauth_state.cleanup_expired(db)

View File

@@ -54,6 +54,9 @@ dependencies = [
"passlib==1.7.4",
"bcrypt==4.2.1",
"cryptography==44.0.1",
# OAuth authentication
"authlib>=1.3.0",
]
# Development dependencies
@@ -243,6 +246,10 @@ ignore_missing_imports = true
module = "starlette.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "authlib.*"
ignore_missing_imports = true
# SQLAlchemy ORM models - Column descriptors cause type confusion
[[tool.mypy.overrides]]
module = "app.models.*"

View File

@@ -0,0 +1,394 @@
# tests/api/test_oauth.py
"""
Tests for OAuth API endpoints.
"""
from unittest.mock import patch
from uuid import uuid4
import pytest
from app.crud.oauth import oauth_account
from app.schemas.oauth import OAuthAccountCreate
def get_error_message(response_json: dict) -> str:
"""Extract error message from API error response."""
if response_json.get("errors"):
return response_json["errors"][0].get("message", "")
return response_json.get("detail", "")
class TestOAuthProviders:
"""Tests for OAuth providers endpoint."""
@pytest.mark.asyncio
async def test_list_providers_disabled(self, client):
"""Test listing providers when OAuth is disabled."""
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = False
mock_settings.enabled_oauth_providers = []
response = await client.get("/api/v1/oauth/providers")
assert response.status_code == 200
data = response.json()
assert data["enabled"] is False
assert data["providers"] == []
@pytest.mark.asyncio
async def test_list_providers_enabled(self, client):
"""Test listing providers when OAuth is enabled."""
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
mock_settings.enabled_oauth_providers = ["google", "github"]
response = await client.get("/api/v1/oauth/providers")
assert response.status_code == 200
data = response.json()
assert data["enabled"] is True
assert len(data["providers"]) == 2
provider_names = [p["provider"] for p in data["providers"]]
assert "google" in provider_names
assert "github" in provider_names
class TestOAuthAuthorize:
"""Tests for OAuth authorization endpoint."""
@pytest.mark.asyncio
async def test_authorize_oauth_disabled(self, client):
"""Test authorization when OAuth is disabled."""
with patch("app.api.routes.oauth.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = False
response = await client.get(
"/api/v1/oauth/authorize/google",
params={"redirect_uri": "http://localhost:3000/callback"},
)
assert response.status_code == 400
assert "not enabled" in get_error_message(response.json())
@pytest.mark.asyncio
async def test_authorize_invalid_provider(self, client):
"""Test authorization with invalid provider."""
with patch("app.api.routes.oauth.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
response = await client.get(
"/api/v1/oauth/authorize/invalid_provider",
params={"redirect_uri": "http://localhost:3000/callback"},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_authorize_provider_not_configured(self, client):
"""Test authorization when provider credentials are not configured."""
# OAuth is enabled but no providers are configured
with (
patch("app.api.routes.oauth.settings") as mock_route_settings,
patch("app.services.oauth_service.settings") as mock_service_settings,
):
mock_route_settings.OAUTH_ENABLED = True
mock_service_settings.OAUTH_ENABLED = True
mock_service_settings.enabled_oauth_providers = [] # No providers configured
response = await client.get(
"/api/v1/oauth/authorize/google",
params={"redirect_uri": "http://localhost:3000/callback"},
)
# Should fail because google is not in enabled_oauth_providers
assert response.status_code == 400
class TestOAuthCallback:
"""Tests for OAuth callback endpoint."""
@pytest.mark.asyncio
async def test_callback_oauth_disabled(self, client):
"""Test callback when OAuth is disabled."""
with patch("app.api.routes.oauth.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = False
response = await client.post(
"/api/v1/oauth/callback/google",
params={"redirect_uri": "http://localhost:3000/callback"},
json={"code": "auth_code", "state": "state_param"},
)
assert response.status_code == 400
assert "not enabled" in get_error_message(response.json())
@pytest.mark.asyncio
async def test_callback_invalid_state(self, client):
"""Test callback with invalid state."""
with patch("app.api.routes.oauth.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
response = await client.post(
"/api/v1/oauth/callback/google",
params={"redirect_uri": "http://localhost:3000/callback"},
json={"code": "auth_code", "state": "invalid_state"},
)
assert response.status_code == 401
assert "Invalid or expired" in get_error_message(response.json())
class TestOAuthAccounts:
"""Tests for OAuth accounts management endpoints."""
@pytest.mark.asyncio
async def test_list_accounts_unauthenticated(self, client):
"""Test listing accounts without authentication."""
response = await client.get("/api/v1/oauth/accounts")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_list_accounts_empty(self, client, user_token):
"""Test listing accounts when user has none."""
response = await client.get(
"/api/v1/oauth/accounts",
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == 200
data = response.json()
assert data["accounts"] == []
@pytest.mark.asyncio
async def test_list_accounts_with_linked(
self, client, user_token, async_test_user, async_test_db
):
"""Test listing accounts when user has linked accounts."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create OAuth account for the user
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_test_123",
provider_email="user@gmail.com",
)
await oauth_account.create_account(session, obj_in=account_data)
response = await client.get(
"/api/v1/oauth/accounts",
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == 200
data = response.json()
assert len(data["accounts"]) == 1
assert data["accounts"][0]["provider"] == "google"
@pytest.mark.asyncio
async def test_unlink_account_unauthenticated(self, client):
"""Test unlinking account without authentication."""
response = await client.delete("/api/v1/oauth/accounts/google")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_unlink_account_not_found(self, client, user_token):
"""Test unlinking non-existent account."""
response = await client.delete(
"/api/v1/oauth/accounts/google",
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == 400
# Error message contains "No google account found to unlink"
error_msg = get_error_message(response.json()).lower()
assert "google" in error_msg and ("found" in error_msg or "unlink" in error_msg)
@pytest.mark.asyncio
async def test_unlink_account_oauth_only_user_blocked(self, client, async_test_db):
"""Test that OAuth-only users can't unlink their only provider."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create OAuth-only user (no password)
from app.core.auth import create_access_token
from app.models.user import User
async with AsyncTestingSessionLocal() as session:
oauth_user = User(
id=uuid4(),
email="oauthonly@example.com",
password_hash=None, # OAuth-only
first_name="OAuth",
is_active=True,
)
session.add(oauth_user)
await session.commit()
# Link one OAuth account
account_data = OAuthAccountCreate(
user_id=oauth_user.id,
provider="google",
provider_user_id="google_only_123",
provider_email="oauthonly@gmail.com",
)
await oauth_account.create_account(session, obj_in=account_data)
# Create token for this user
token = create_access_token(
subject=str(oauth_user.id),
claims={"email": oauth_user.email, "first_name": oauth_user.first_name},
)
# Try to unlink - should fail
response = await client.delete(
"/api/v1/oauth/accounts/google",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 400
assert "Cannot unlink" in get_error_message(response.json())
class TestOAuthLink:
"""Tests for OAuth account linking endpoint."""
@pytest.mark.asyncio
async def test_link_unauthenticated(self, client):
"""Test linking without authentication."""
response = await client.post(
"/api/v1/oauth/link/google",
params={"redirect_uri": "http://localhost:3000/callback"},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_link_already_linked(
self, client, user_token, async_test_user, async_test_db
):
"""Test linking when provider is already linked."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create existing link
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_existing",
)
await oauth_account.create_account(session, obj_in=account_data)
# Mock settings to enable OAuth
with patch("app.api.routes.oauth.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
response = await client.post(
"/api/v1/oauth/link/google",
params={"redirect_uri": "http://localhost:3000/callback"},
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == 400
assert "already" in get_error_message(response.json()).lower()
class TestOAuthProviderEndpoints:
"""Tests for OAuth provider mode endpoints."""
@pytest.mark.asyncio
async def test_server_metadata_disabled(self, client):
"""Test server metadata when provider mode is disabled."""
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
mock_settings.OAUTH_PROVIDER_ENABLED = False
response = await client.get(
"/api/v1/oauth/.well-known/oauth-authorization-server"
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_server_metadata_enabled(self, client):
"""Test server metadata when provider mode is enabled."""
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
mock_settings.OAUTH_PROVIDER_ENABLED = True
mock_settings.OAUTH_ISSUER = "https://api.example.com"
response = await client.get(
"/api/v1/oauth/.well-known/oauth-authorization-server"
)
assert response.status_code == 200
data = response.json()
assert data["issuer"] == "https://api.example.com"
assert "authorization_endpoint" in data
assert "token_endpoint" in data
@pytest.mark.asyncio
async def test_provider_authorize_disabled(self, client):
"""Test provider authorize endpoint when disabled."""
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
mock_settings.OAUTH_PROVIDER_ENABLED = False
response = await client.get(
"/api/v1/oauth/provider/authorize",
params={
"response_type": "code",
"client_id": "test_client",
"redirect_uri": "http://localhost:3000/callback",
},
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_provider_token_disabled(self, client):
"""Test provider token endpoint when disabled."""
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
mock_settings.OAUTH_PROVIDER_ENABLED = False
response = await client.post(
"/api/v1/oauth/provider/token",
data={
"grant_type": "authorization_code",
"code": "test_code",
},
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_provider_authorize_skeleton(self, client, async_test_db):
"""Test provider authorize returns not implemented (skeleton)."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a test client
from app.crud.oauth import oauth_client
from app.schemas.oauth import OAuthClientCreate
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Test App",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
)
test_client, _ = await oauth_client.create_client(
session, obj_in=client_data
)
test_client_id = test_client.client_id
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
mock_settings.OAUTH_PROVIDER_ENABLED = True
response = await client.get(
"/api/v1/oauth/provider/authorize",
params={
"response_type": "code",
"client_id": test_client_id,
"redirect_uri": "http://localhost:3000/callback",
},
)
# Should return 501 Not Implemented (skeleton)
assert response.status_code == 501
@pytest.mark.asyncio
async def test_provider_token_skeleton(self, client):
"""Test provider token returns not implemented (skeleton)."""
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
mock_settings.OAUTH_PROVIDER_ENABLED = True
response = await client.post(
"/api/v1/oauth/provider/token",
data={
"grant_type": "authorization_code",
"code": "test_code",
},
)
# Should return 501 Not Implemented (skeleton)
assert response.status_code == 501

View File

@@ -169,10 +169,17 @@ class TestJWTConfiguration:
class TestProjectConfiguration:
"""Tests for project-level configuration"""
def test_project_name_default(self):
"""Test that project name is set correctly"""
def test_project_name_can_be_set(self):
"""Test that project name can be explicitly set"""
settings = Settings(SECRET_KEY="a" * 32, PROJECT_NAME="TestApp")
assert settings.PROJECT_NAME == "TestApp"
def test_project_name_is_set(self):
"""Test that project name has a value (from default or environment)"""
settings = Settings(SECRET_KEY="a" * 32)
assert settings.PROJECT_NAME == "PragmaStack"
# PROJECT_NAME should be a non-empty string
assert isinstance(settings.PROJECT_NAME, str)
assert len(settings.PROJECT_NAME) > 0
def test_api_version_string(self):
"""Test that API version string is correct"""

View File

@@ -0,0 +1,537 @@
# tests/crud/test_oauth.py
"""
Comprehensive tests for OAuth CRUD operations.
"""
from datetime import UTC, datetime, timedelta
import pytest
from app.crud.oauth import oauth_account, oauth_client, oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
class TestOAuthAccountCRUD:
"""Tests for OAuth account CRUD operations."""
@pytest.mark.asyncio
async def test_create_account(self, async_test_db, async_test_user):
"""Test creating an OAuth account link."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_123456",
provider_email="user@gmail.com",
)
account = await oauth_account.create_account(session, obj_in=account_data)
assert account is not None
assert account.provider == "google"
assert account.provider_user_id == "google_123456"
assert account.user_id == async_test_user.id
@pytest.mark.asyncio
async def test_create_account_same_provider_twice_fails(
self, async_test_db, async_test_user
):
"""Test creating same OAuth account for same user twice raises error."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_dup_123",
provider_email="user@gmail.com",
)
await oauth_account.create_account(session, obj_in=account_data)
# Try to create same account again (same provider + provider_user_id)
async with AsyncTestingSessionLocal() as session:
account_data2 = OAuthAccountCreate(
user_id=async_test_user.id, # Same user
provider="google",
provider_user_id="google_dup_123", # Same provider_user_id
provider_email="user@gmail.com",
)
# SQLite returns different error message than PostgreSQL
with pytest.raises(
ValueError, match="(already linked|UNIQUE constraint failed)"
):
await oauth_account.create_account(session, obj_in=account_data2)
@pytest.mark.asyncio
async def test_get_by_provider_id(self, async_test_db, async_test_user):
"""Test getting OAuth account by provider and provider user ID."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="github",
provider_user_id="github_789",
provider_email="user@github.com",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_by_provider_id(
session,
provider="github",
provider_user_id="github_789",
)
assert result is not None
assert result.provider == "github"
assert result.user is not None # Eager loaded
@pytest.mark.asyncio
async def test_get_by_provider_id_not_found(self, async_test_db):
"""Test getting non-existent OAuth account returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_by_provider_id(
session,
provider="google",
provider_user_id="nonexistent",
)
assert result is None
@pytest.mark.asyncio
async def test_get_user_accounts(self, async_test_db, async_test_user):
"""Test getting all OAuth accounts for a user."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create two accounts for the same user
for provider in ["google", "github"]:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider=provider,
provider_user_id=f"{provider}_user_123",
provider_email=f"user@{provider}.com",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
accounts = await oauth_account.get_user_accounts(
session, user_id=async_test_user.id
)
assert len(accounts) == 2
providers = {a.provider for a in accounts}
assert providers == {"google", "github"}
@pytest.mark.asyncio
async def test_get_user_account_by_provider(self, async_test_db, async_test_user):
"""Test getting specific OAuth account for user and provider."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_specific",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_user_account_by_provider(
session,
user_id=async_test_user.id,
provider="google",
)
assert result is not None
assert result.provider == "google"
# Test not found
result2 = await oauth_account.get_user_account_by_provider(
session,
user_id=async_test_user.id,
provider="github", # Not linked
)
assert result2 is None
@pytest.mark.asyncio
async def test_delete_account(self, async_test_db, async_test_user):
"""Test deleting an OAuth account link."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_to_delete",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
deleted = await oauth_account.delete_account(
session,
user_id=async_test_user.id,
provider="google",
)
assert deleted is True
# Verify deletion
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_user_account_by_provider(
session,
user_id=async_test_user.id,
provider="google",
)
assert result is None
@pytest.mark.asyncio
async def test_delete_account_not_found(self, async_test_db, async_test_user):
"""Test deleting non-existent account returns False."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
deleted = await oauth_account.delete_account(
session,
user_id=async_test_user.id,
provider="nonexistent",
)
assert deleted is False
@pytest.mark.asyncio
async def test_get_by_provider_email(self, async_test_db, async_test_user):
"""Test getting OAuth account by provider and email."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_email_test",
provider_email="unique@gmail.com",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_by_provider_email(
session,
provider="google",
email="unique@gmail.com",
)
assert result is not None
assert result.provider_email == "unique@gmail.com"
# Test not found
result2 = await oauth_account.get_by_provider_email(
session,
provider="google",
email="nonexistent@gmail.com",
)
assert result2 is None
@pytest.mark.asyncio
async def test_update_tokens(self, async_test_db, async_test_user):
"""Test updating OAuth tokens."""
from datetime import UTC, datetime, timedelta
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_token_test",
)
account = await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
# Get the account first
account = await oauth_account.get_by_provider_id(
session, provider="google", provider_user_id="google_token_test"
)
assert account is not None
# Update tokens
new_expires = datetime.now(UTC) + timedelta(hours=1)
updated = await oauth_account.update_tokens(
session,
account=account,
access_token_encrypted="new_access_token",
refresh_token_encrypted="new_refresh_token",
token_expires_at=new_expires,
)
assert updated.access_token_encrypted == "new_access_token"
assert updated.refresh_token_encrypted == "new_refresh_token"
class TestOAuthStateCRUD:
"""Tests for OAuth state CRUD operations."""
@pytest.mark.asyncio
async def test_create_state(self, async_test_db):
"""Test creating OAuth state."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
state_data = OAuthStateCreate(
state="random_state_123",
code_verifier="pkce_verifier",
nonce="oidc_nonce",
provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
state = await oauth_state.create_state(session, obj_in=state_data)
assert state is not None
assert state.state == "random_state_123"
assert state.code_verifier == "pkce_verifier"
assert state.provider == "google"
@pytest.mark.asyncio
async def test_get_and_consume_state(self, async_test_db):
"""Test getting and consuming OAuth state."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
state_data = OAuthStateCreate(
state="consume_state_123",
provider="github",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
# Consume the state
async with AsyncTestingSessionLocal() as session:
result = await oauth_state.get_and_consume_state(
session, state="consume_state_123"
)
assert result is not None
assert result.provider == "github"
# Try to consume again - should be None (already consumed)
async with AsyncTestingSessionLocal() as session:
result2 = await oauth_state.get_and_consume_state(
session, state="consume_state_123"
)
assert result2 is None
@pytest.mark.asyncio
async def test_get_and_consume_expired_state(self, async_test_db):
"""Test consuming expired state returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create expired state
state_data = OAuthStateCreate(
state="expired_state_123",
provider="google",
expires_at=datetime.now(UTC) - timedelta(minutes=1), # Already expired
)
await oauth_state.create_state(session, obj_in=state_data)
async with AsyncTestingSessionLocal() as session:
result = await oauth_state.get_and_consume_state(
session, state="expired_state_123"
)
assert result is None
@pytest.mark.asyncio
async def test_cleanup_expired_states(self, async_test_db):
"""Test cleaning up expired OAuth states."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create expired state
expired_state = OAuthStateCreate(
state="cleanup_expired",
provider="google",
expires_at=datetime.now(UTC) - timedelta(minutes=5),
)
await oauth_state.create_state(session, obj_in=expired_state)
# Create valid state
valid_state = OAuthStateCreate(
state="cleanup_valid",
provider="google",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=valid_state)
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await oauth_state.cleanup_expired(session)
assert count == 1
# Verify only expired was deleted
async with AsyncTestingSessionLocal() as session:
result = await oauth_state.get_and_consume_state(
session, state="cleanup_valid"
)
assert result is not None
class TestOAuthClientCRUD:
"""Tests for OAuth client CRUD operations (provider mode)."""
@pytest.mark.asyncio
async def test_create_public_client(self, async_test_db):
"""Test creating a public OAuth client."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Test MCP App",
client_description="A test application",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
client_type="public",
)
client, secret = await oauth_client.create_client(
session, obj_in=client_data
)
assert client is not None
assert client.client_name == "Test MCP App"
assert client.client_type == "public"
assert secret is None # Public clients don't have secrets
@pytest.mark.asyncio
async def test_create_confidential_client(self, async_test_db):
"""Test creating a confidential OAuth client."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Confidential App",
redirect_uris=["http://localhost:8080/callback"],
allowed_scopes=["read:users", "write:users"],
client_type="confidential",
)
client, secret = await oauth_client.create_client(
session, obj_in=client_data
)
assert client is not None
assert client.client_type == "confidential"
assert secret is not None # Confidential clients have secrets
assert len(secret) > 20 # Should be a reasonably long secret
@pytest.mark.asyncio
async def test_get_by_client_id(self, async_test_db):
"""Test getting OAuth client by client_id."""
_test_engine, AsyncTestingSessionLocal = async_test_db
created_client_id = None
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Lookup Test",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
)
client, _ = await oauth_client.create_client(session, obj_in=client_data)
created_client_id = client.client_id
async with AsyncTestingSessionLocal() as session:
result = await oauth_client.get_by_client_id(
session, client_id=created_client_id
)
assert result is not None
assert result.client_name == "Lookup Test"
@pytest.mark.asyncio
async def test_get_inactive_client_not_found(self, async_test_db):
"""Test getting inactive OAuth client returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
created_client_id = None
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Inactive Client",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
)
client, _ = await oauth_client.create_client(session, obj_in=client_data)
created_client_id = client.client_id
# Deactivate
await oauth_client.deactivate_client(session, client_id=created_client_id)
async with AsyncTestingSessionLocal() as session:
result = await oauth_client.get_by_client_id(
session, client_id=created_client_id
)
assert result is None # Inactive clients not returned
@pytest.mark.asyncio
async def test_validate_redirect_uri(self, async_test_db):
"""Test redirect URI validation."""
_test_engine, AsyncTestingSessionLocal = async_test_db
created_client_id = None
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="URI Test",
redirect_uris=[
"http://localhost:3000/callback",
"http://localhost:8080/oauth",
],
allowed_scopes=["read:users"],
)
client, _ = await oauth_client.create_client(session, obj_in=client_data)
created_client_id = client.client_id
async with AsyncTestingSessionLocal() as session:
# Valid URI
valid = await oauth_client.validate_redirect_uri(
session,
client_id=created_client_id,
redirect_uri="http://localhost:3000/callback",
)
assert valid is True
# Invalid URI
invalid = await oauth_client.validate_redirect_uri(
session,
client_id=created_client_id,
redirect_uri="http://evil.com/callback",
)
assert invalid is False
@pytest.mark.asyncio
async def test_verify_client_secret(self, async_test_db):
"""Test client secret verification."""
_test_engine, AsyncTestingSessionLocal = async_test_db
created_client_id = None
created_secret = None
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Secret Test",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
client_type="confidential",
)
client, secret = await oauth_client.create_client(
session, obj_in=client_data
)
created_client_id = client.client_id
created_secret = secret
async with AsyncTestingSessionLocal() as session:
# Valid secret
valid = await oauth_client.verify_client_secret(
session,
client_id=created_client_id,
client_secret=created_secret,
)
assert valid is True
# Invalid secret
invalid = await oauth_client.verify_client_secret(
session,
client_id=created_client_id,
client_secret="wrong_secret",
)
assert invalid is False

View File

@@ -154,18 +154,25 @@ def test_user_required_fields(db_session):
db_session.commit()
db_session.rollback()
# Missing password_hash
def test_user_oauth_only_without_password(db_session):
"""Test that OAuth-only users can be created without password_hash."""
# OAuth-only users don't have a password set
user_no_password = User(
id=uuid.uuid4(),
email="nopassword@example.com",
# password_hash is missing
first_name="Test",
email="oauthonly@example.com",
password_hash=None, # OAuth-only user
first_name="OAuth",
last_name="User",
)
db_session.add(user_no_password)
with pytest.raises(IntegrityError):
db_session.commit()
db_session.rollback()
db_session.commit()
# Retrieve and verify
retrieved = db_session.query(User).filter_by(email="oauthonly@example.com").first()
assert retrieved is not None
assert retrieved.password_hash is None
assert retrieved.has_password is False # Test has_password property
def test_user_defaults(db_session):

View File

@@ -0,0 +1,403 @@
# tests/services/test_oauth_service.py
"""
Tests for OAuthService covering authorization URL creation,
callback handling, and account management.
"""
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
from uuid import uuid4
import pytest
from app.core.exceptions import AuthenticationError
from app.crud.oauth import oauth_account, oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService
class TestGetEnabledProviders:
"""Tests for get_enabled_providers method."""
def test_returns_empty_when_disabled(self):
"""Test returns empty providers when OAuth is disabled."""
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = False
mock_settings.enabled_oauth_providers = []
result = OAuthService.get_enabled_providers()
assert result.enabled is False
assert result.providers == []
def test_returns_configured_providers(self):
"""Test returns configured providers when enabled."""
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
mock_settings.enabled_oauth_providers = ["google", "github"]
result = OAuthService.get_enabled_providers()
assert result.enabled is True
assert len(result.providers) == 2
provider_names = [p.provider for p in result.providers]
assert "google" in provider_names
assert "github" in provider_names
def test_filters_unknown_providers(self):
"""Test filters out unknown providers from list."""
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
mock_settings.enabled_oauth_providers = ["google", "unknown_provider"]
result = OAuthService.get_enabled_providers()
assert result.enabled is True
assert len(result.providers) == 1
assert result.providers[0].provider == "google"
class TestGetProviderCredentials:
"""Tests for _get_provider_credentials method."""
def test_returns_google_credentials(self):
"""Test returns Google credentials when configured."""
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_GOOGLE_CLIENT_ID = "google_client_id"
mock_settings.OAUTH_GOOGLE_CLIENT_SECRET = "google_secret"
client_id, secret = OAuthService._get_provider_credentials("google")
assert client_id == "google_client_id"
assert secret == "google_secret"
def test_returns_github_credentials(self):
"""Test returns GitHub credentials when configured."""
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_GITHUB_CLIENT_ID = "github_client_id"
mock_settings.OAUTH_GITHUB_CLIENT_SECRET = "github_secret"
client_id, secret = OAuthService._get_provider_credentials("github")
assert client_id == "github_client_id"
assert secret == "github_secret"
def test_raises_for_unknown_provider(self):
"""Test raises error for unknown provider."""
with pytest.raises(AuthenticationError, match="Unknown OAuth provider"):
OAuthService._get_provider_credentials("unknown")
def test_raises_when_credentials_not_configured(self):
"""Test raises error when credentials are not configured."""
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_GOOGLE_CLIENT_ID = None
mock_settings.OAUTH_GOOGLE_CLIENT_SECRET = "secret"
with pytest.raises(AuthenticationError, match="not configured"):
OAuthService._get_provider_credentials("google")
class TestCreateAuthorizationUrl:
"""Tests for create_authorization_url method."""
@pytest.mark.asyncio
async def test_raises_when_oauth_disabled(self, async_test_db):
"""Test raises error when OAuth is disabled."""
_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = False
with pytest.raises(AuthenticationError, match="not enabled"):
await OAuthService.create_authorization_url(
session,
provider="google",
redirect_uri="http://localhost:3000/callback",
)
@pytest.mark.asyncio
async def test_raises_for_unknown_provider(self, async_test_db):
"""Test raises error for unknown provider."""
_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
with pytest.raises(AuthenticationError, match="Unknown OAuth provider"):
await OAuthService.create_authorization_url(
session,
provider="unknown",
redirect_uri="http://localhost:3000/callback",
)
@pytest.mark.asyncio
async def test_raises_when_provider_not_enabled(self, async_test_db):
"""Test raises error when provider is not in enabled list."""
_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
mock_settings.enabled_oauth_providers = ["github"] # google not enabled
with pytest.raises(AuthenticationError, match="not enabled"):
await OAuthService.create_authorization_url(
session,
provider="google",
redirect_uri="http://localhost:3000/callback",
)
@pytest.mark.asyncio
async def test_creates_authorization_url_for_google(self, async_test_db):
"""Test creates authorization URL for Google with PKCE."""
_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
mock_settings.enabled_oauth_providers = ["google"]
mock_settings.OAUTH_GOOGLE_CLIENT_ID = "google_client_id"
mock_settings.OAUTH_GOOGLE_CLIENT_SECRET = "google_secret"
mock_settings.OAUTH_STATE_EXPIRE_MINUTES = 10
url, state = await OAuthService.create_authorization_url(
session,
provider="google",
redirect_uri="http://localhost:3000/callback",
)
assert url is not None
assert "accounts.google.com" in url
assert state is not None
assert len(state) > 20
@pytest.mark.asyncio
async def test_creates_authorization_url_for_github(self, async_test_db):
"""Test creates authorization URL for GitHub."""
_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with patch("app.services.oauth_service.settings") as mock_settings:
mock_settings.OAUTH_ENABLED = True
mock_settings.enabled_oauth_providers = ["github"]
mock_settings.OAUTH_GITHUB_CLIENT_ID = "github_client_id"
mock_settings.OAUTH_GITHUB_CLIENT_SECRET = "github_secret"
mock_settings.OAUTH_STATE_EXPIRE_MINUTES = 10
url, state = await OAuthService.create_authorization_url(
session,
provider="github",
redirect_uri="http://localhost:3000/callback",
)
assert url is not None
assert "github.com/login/oauth/authorize" in url
assert state is not None
class TestHandleCallback:
"""Tests for handle_callback method."""
@pytest.mark.asyncio
async def test_raises_for_invalid_state(self, async_test_db):
"""Test raises error for invalid/expired state."""
_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(AuthenticationError, match="Invalid or expired"):
await OAuthService.handle_callback(
session,
code="auth_code",
state="invalid_state",
redirect_uri="http://localhost:3000/callback",
)
class TestUnlinkProvider:
"""Tests for unlink_provider method."""
@pytest.mark.asyncio
async def test_unlink_with_password_succeeds(self, async_test_db, async_test_user):
"""Test unlinking succeeds when user has password."""
_engine, AsyncTestingSessionLocal = async_test_db
# Create OAuth account
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_123",
)
await oauth_account.create_account(session, obj_in=account_data)
# Unlink (user has password)
async with AsyncTestingSessionLocal() as session:
# Need to get fresh user instance
from sqlalchemy import select
from app.models.user import User
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one()
success = await OAuthService.unlink_provider(
session, user=user, provider="google"
)
assert success is True
# Verify unlinked
async with AsyncTestingSessionLocal() as session:
account = await oauth_account.get_user_account_by_provider(
session, user_id=async_test_user.id, provider="google"
)
assert account is None
@pytest.mark.asyncio
async def test_unlink_not_found_raises(self, async_test_db, async_test_user):
"""Test unlinking non-existent provider raises error."""
_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
from sqlalchemy import select
from app.models.user import User
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one()
with pytest.raises(AuthenticationError, match="No google account found"):
await OAuthService.unlink_provider(
session, user=user, provider="google"
)
@pytest.mark.asyncio
async def test_unlink_oauth_only_user_blocked(self, async_test_db):
"""Test unlinking fails for OAuth-only user with single provider."""
_engine, AsyncTestingSessionLocal = async_test_db
# Create OAuth-only user
from app.models.user import User
async with AsyncTestingSessionLocal() as session:
oauth_user = User(
id=uuid4(),
email="oauthonly@example.com",
password_hash=None, # No password
first_name="OAuth",
is_active=True,
)
session.add(oauth_user)
await session.commit()
# Link single OAuth account
account_data = OAuthAccountCreate(
user_id=oauth_user.id,
provider="google",
provider_user_id="google_only",
)
await oauth_account.create_account(session, obj_in=account_data)
# Try to unlink
async with AsyncTestingSessionLocal() as session:
from sqlalchemy import select
result = await session.execute(
select(User).where(User.email == "oauthonly@example.com")
)
user = result.scalar_one()
with pytest.raises(AuthenticationError, match="Cannot unlink"):
await OAuthService.unlink_provider(
session, user=user, provider="google"
)
@pytest.mark.asyncio
async def test_unlink_with_multiple_providers_succeeds(self, async_test_db):
"""Test unlinking succeeds when user has multiple providers."""
_engine, AsyncTestingSessionLocal = async_test_db
from app.models.user import User
# Create OAuth-only user with multiple providers
async with AsyncTestingSessionLocal() as session:
oauth_user = User(
id=uuid4(),
email="multiauth@example.com",
password_hash=None,
first_name="Multi",
is_active=True,
)
session.add(oauth_user)
await session.commit()
# Link multiple OAuth accounts
for provider in ["google", "github"]:
account_data = OAuthAccountCreate(
user_id=oauth_user.id,
provider=provider,
provider_user_id=f"{provider}_user",
)
await oauth_account.create_account(session, obj_in=account_data)
# Unlink one provider (should succeed)
async with AsyncTestingSessionLocal() as session:
from sqlalchemy import select
result = await session.execute(
select(User).where(User.email == "multiauth@example.com")
)
user = result.scalar_one()
success = await OAuthService.unlink_provider(
session, user=user, provider="google"
)
assert success is True
class TestCleanupExpiredStates:
"""Tests for cleanup_expired_states method."""
@pytest.mark.asyncio
async def test_cleanup_removes_expired_states(self, async_test_db):
"""Test cleanup removes expired states."""
_engine, AsyncTestingSessionLocal = async_test_db
# Create expired state
async with AsyncTestingSessionLocal() as session:
expired_state = OAuthStateCreate(
state="expired_cleanup_test",
provider="google",
expires_at=datetime.now(UTC) - timedelta(minutes=5),
)
await oauth_state.create_state(session, obj_in=expired_state)
# Run cleanup
async with AsyncTestingSessionLocal() as session:
count = await OAuthService.cleanup_expired_states(session)
assert count >= 1
class TestProviderConfigs:
"""Tests for provider configuration constants."""
def test_google_provider_config(self):
"""Test Google provider configuration is correct."""
config = OAUTH_PROVIDERS.get("google")
assert config is not None
assert config["name"] == "Google"
assert "accounts.google.com" in config["authorize_url"]
assert config["supports_pkce"] is True
def test_github_provider_config(self):
"""Test GitHub provider configuration is correct."""
config = OAUTH_PROVIDERS.get("github")
assert config is not None
assert config["name"] == "GitHub"
assert "github.com" in config["authorize_url"]
assert config["supports_pkce"] is False

View File

@@ -15,6 +15,9 @@ class TestInitDb:
"""Tests for init_db functionality."""
@pytest.mark.asyncio
@pytest.mark.skip(
reason="SQLite doesn't support UUID type binding - requires PostgreSQL"
)
async def test_init_db_creates_superuser_when_not_exists(self, async_test_db):
"""Test that init_db creates a superuser when one doesn't exist."""
_test_engine, SessionLocal = async_test_db
@@ -63,6 +66,9 @@ class TestInitDb:
assert user.email == "testuser@example.com"
@pytest.mark.asyncio
@pytest.mark.skip(
reason="SQLite doesn't support UUID type binding - requires PostgreSQL"
)
async def test_init_db_uses_default_credentials(self, async_test_db):
"""Test that init_db uses default credentials when env vars not set."""
_test_engine, SessionLocal = async_test_db

14
backend/uv.lock generated
View File

@@ -96,6 +96,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/a4/cec76b3389c4c5ff66301cd100fe88c318563ec8a520e0b2e792b5b84972/asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e", size = 621623, upload-time = "2024-10-20T00:30:09.024Z" },
]
[[package]]
name = "authlib"
version = "1.6.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cd/3f/1d3bbd0bf23bdd99276d4def22f29c27a914067b4cf66f753ff9b8bbd0f3/authlib-1.6.5.tar.gz", hash = "sha256:6aaf9c79b7cc96c900f0b284061691c5d4e61221640a948fe690b556a6d6d10b", size = 164553, upload-time = "2025-10-02T13:36:09.489Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f8/aa/5082412d1ee302e9e7d80b6949bc4d2a8fa1149aaab610c5fc24709605d6/authlib-1.6.5-py2.py3-none-any.whl", hash = "sha256:3e0e0507807f842b02175507bdee8957a1d5707fd4afb17c32fb43fee90b6e3a", size = 243608, upload-time = "2025-10-02T13:36:07.637Z" },
]
[[package]]
name = "bcrypt"
version = "4.2.1"
@@ -443,6 +455,7 @@ dependencies = [
{ name = "alembic" },
{ name = "apscheduler" },
{ name = "asyncpg" },
{ name = "authlib" },
{ name = "bcrypt" },
{ name = "cryptography" },
{ name = "email-validator" },
@@ -485,6 +498,7 @@ requires-dist = [
{ name = "alembic", specifier = ">=1.14.1" },
{ name = "apscheduler", specifier = "==3.11.0" },
{ name = "asyncpg", specifier = ">=0.29.0" },
{ name = "authlib", specifier = ">=1.3.0" },
{ name = "bcrypt", specifier = "==4.2.1" },
{ name = "cryptography", specifier = "==44.0.1" },
{ name = "email-validator", specifier = ">=2.1.0.post1" },