Add OAuth provider mode and MCP integration

- Introduced full OAuth 2.0 Authorization Server functionality for MCP clients.
- Updated documentation with details on endpoints, scopes, and consent management.
- Added a new frontend OAuth consent page for user authorization flows.
- Implemented database models for authorization codes, refresh tokens, and user consents.
- Created unit tests for service methods (PKCE verification, client validation, scope handling).
- Included comprehensive integration tests for OAuth provider workflows.
This commit is contained in:
Felipe Cardoso
2025-11-25 23:18:19 +01:00
parent fbb030da69
commit 48f052200f
12 changed files with 3335 additions and 142 deletions

View File

@@ -69,6 +69,27 @@ Default superuser (change in production):
- `get_optional_current_user`: Accepts authenticated or anonymous
- `get_current_superuser`: Requires superuser flag
### OAuth Provider Mode (MCP Integration)
Full OAuth 2.0 Authorization Server for MCP (Model Context Protocol) clients:
- **Authorization Code Flow with PKCE**: RFC 7636 compliant
- **JWT access tokens**: Self-contained, no DB lookup required
- **Opaque refresh tokens**: Stored hashed in database, supports rotation
- **Token introspection**: RFC 7662 compliant endpoint
- **Token revocation**: RFC 7009 compliant endpoint
- **Server metadata**: RFC 8414 compliant discovery endpoint
- **Consent management**: User can review and revoke app permissions
**API endpoints:**
- `GET /.well-known/oauth-authorization-server` - Server metadata
- `GET /oauth/provider/authorize` - Authorization endpoint
- `POST /oauth/provider/authorize/consent` - Consent submission
- `POST /oauth/provider/token` - Token endpoint
- `POST /oauth/provider/revoke` - Token revocation
- `POST /oauth/provider/introspect` - Token introspection
- Client management endpoints (admin only)
**Scopes supported:** `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
### Database Pattern
- **Async SQLAlchemy 2.0** with PostgreSQL
- **Connection pooling**: 20 base connections, 50 max overflow
@@ -238,6 +259,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
### Completed Features ✅
- Authentication system (JWT with refresh tokens, OAuth/social login)
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
- Session management (device tracking, revocation)
- User management (CRUD, password change)
- Organization system (multi-tenant with RBAC)

View File

@@ -0,0 +1,194 @@
"""Add OAuth provider models for MCP integration.
Revision ID: f8c3d2e1a4b5
Revises: d5a7b2c9e1f3
Create Date: 2025-01-15 10:00:00.000000
This migration adds tables for OAuth provider mode:
- oauth_authorization_codes: Temporary authorization codes
- oauth_provider_refresh_tokens: Long-lived refresh tokens
- oauth_consents: User consent records
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "f8c3d2e1a4b5"
down_revision = "d5a7b2c9e1f3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create oauth_authorization_codes table
op.create_table(
"oauth_authorization_codes",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("code", sa.String(128), nullable=False),
sa.Column("client_id", sa.String(64), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("redirect_uri", sa.String(2048), nullable=False),
sa.Column("scope", sa.String(1000), nullable=False, server_default=""),
sa.Column("code_challenge", sa.String(128), nullable=True),
sa.Column("code_challenge_method", sa.String(10), nullable=True),
sa.Column("state", sa.String(256), nullable=True),
sa.Column("nonce", sa.String(256), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("used", sa.Boolean(), nullable=False, server_default="false"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_clients.client_id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_authorization_codes_code",
"oauth_authorization_codes",
["code"],
unique=True,
)
op.create_index(
"ix_oauth_authorization_codes_expires_at",
"oauth_authorization_codes",
["expires_at"],
)
op.create_index(
"ix_oauth_authorization_codes_client_user",
"oauth_authorization_codes",
["client_id", "user_id"],
)
# Create oauth_provider_refresh_tokens table
op.create_table(
"oauth_provider_refresh_tokens",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("token_hash", sa.String(64), nullable=False),
sa.Column("jti", sa.String(64), nullable=False),
sa.Column("client_id", sa.String(64), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("scope", sa.String(1000), nullable=False, server_default=""),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False, server_default="false"),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("device_info", sa.String(500), nullable=True),
sa.Column("ip_address", sa.String(45), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_clients.client_id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_provider_refresh_tokens_token_hash",
"oauth_provider_refresh_tokens",
["token_hash"],
unique=True,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_jti",
"oauth_provider_refresh_tokens",
["jti"],
unique=True,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_expires_at",
"oauth_provider_refresh_tokens",
["expires_at"],
)
op.create_index(
"ix_oauth_provider_refresh_tokens_client_user",
"oauth_provider_refresh_tokens",
["client_id", "user_id"],
)
op.create_index(
"ix_oauth_provider_refresh_tokens_user_revoked",
"oauth_provider_refresh_tokens",
["user_id", "revoked"],
)
op.create_index(
"ix_oauth_provider_refresh_tokens_revoked",
"oauth_provider_refresh_tokens",
["revoked"],
)
# Create oauth_consents table
op.create_table(
"oauth_consents",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("client_id", sa.String(64), nullable=False),
sa.Column("granted_scopes", sa.String(1000), nullable=False, server_default=""),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_clients.client_id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_consents_user_client",
"oauth_consents",
["user_id", "client_id"],
unique=True,
)
def downgrade() -> None:
op.drop_table("oauth_consents")
op.drop_table("oauth_provider_refresh_tokens")
op.drop_table("oauth_authorization_codes")

View File

@@ -1,37 +1,63 @@
# app/api/routes/oauth_provider.py
"""
OAuth Provider routes (Authorization Server mode).
OAuth Provider routes (Authorization Server mode) for MCP integration.
This is a skeleton implementation for MCP (Model Context Protocol) client authentication.
Provides basic OAuth 2.0 endpoints that can be expanded for full functionality.
Endpoints:
Implements OAuth 2.0 Authorization Server endpoints:
- GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414)
- GET /oauth/provider/authorize - Authorization endpoint (skeleton)
- POST /oauth/provider/token - Token endpoint (skeleton)
- POST /oauth/provider/revoke - Token revocation endpoint (skeleton)
- GET /oauth/provider/authorize - Authorization endpoint
- POST /oauth/provider/token - Token endpoint
- POST /oauth/provider/revoke - Token revocation (RFC 7009)
- POST /oauth/provider/introspect - Token introspection (RFC 7662)
- Client management endpoints
NOTE: This is intentionally minimal. Full implementation should include:
- Complete authorization code flow
- Refresh token handling
- Scope validation
- Client authentication
- PKCE support
Security features:
- PKCE required for public clients (S256)
- CSRF protection via state parameter
- Secure token handling
- Rate limiting on sensitive endpoints
"""
import logging
from typing import Any
from urllib.parse import urlencode
from fastapi import APIRouter, Depends, Form, HTTPException, Query, status
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, status
from fastapi.responses import RedirectResponse
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_active_user, get_current_superuser
from app.core.config import settings
from app.core.database import get_db
from app.crud import oauth_client
from app.schemas.oauth import OAuthServerMetadata
from app.crud import oauth_client as oauth_client_crud
from app.models.user import User
from app.schemas.oauth import (
OAuthClientCreate,
OAuthClientResponse,
OAuthServerMetadata,
OAuthTokenIntrospectionResponse,
OAuthTokenResponse,
)
from app.services import oauth_provider_service as provider_service
router = APIRouter()
logger = logging.getLogger(__name__)
limiter = Limiter(key_func=get_remote_address)
def require_provider_enabled():
"""Dependency to check if OAuth provider mode is enabled."""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled. Set OAUTH_PROVIDER_ENABLED=true",
)
# ============================================================================
# Server Metadata (RFC 8414)
# ============================================================================
@router.get(
@@ -42,24 +68,15 @@ logger = logging.getLogger(__name__)
OAuth 2.0 Authorization Server Metadata (RFC 8414).
Returns server metadata including supported endpoints, scopes,
and capabilities for MCP clients.
and capabilities. MCP clients use this to discover the server.
""",
operation_id="get_oauth_server_metadata",
tags=["OAuth Provider"],
)
async def get_server_metadata() -> Any:
"""
Get OAuth 2.0 server metadata.
This endpoint is used by MCP clients to discover the authorization
server's capabilities.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
)
async def get_server_metadata(
_: None = Depends(require_provider_enabled),
) -> OAuthServerMetadata:
"""Get OAuth 2.0 server metadata."""
base_url = settings.OAUTH_ISSUER.rstrip("/")
return OAuthServerMetadata(
@@ -67,7 +84,8 @@ async def get_server_metadata() -> Any:
authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize",
token_endpoint=f"{base_url}/api/v1/oauth/provider/token",
revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke",
registration_endpoint=None, # Dynamic registration not implemented
introspection_endpoint=f"{base_url}/api/v1/oauth/provider/introspect",
registration_endpoint=None, # Dynamic registration not supported
scopes_supported=[
"openid",
"profile",
@@ -76,148 +94,441 @@ async def get_server_metadata() -> Any:
"write:users",
"read:organizations",
"write:organizations",
"admin",
],
response_types_supported=["code"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["S256"],
token_endpoint_auth_methods_supported=[
"client_secret_basic",
"client_secret_post",
"none", # For public clients with PKCE
],
)
# ============================================================================
# Authorization Endpoint
# ============================================================================
@router.get(
"/provider/authorize",
summary="Authorization Endpoint (Skeleton)",
summary="Authorization Endpoint",
description="""
OAuth 2.0 Authorization Endpoint.
**NOTE**: This is a skeleton implementation. In a full implementation,
this would:
1. Validate client_id and redirect_uri
2. Display consent screen to user
3. Generate authorization code
4. Redirect back to client with code
Initiates the authorization code flow:
1. Validates client and parameters
2. Checks if user is authenticated (redirects to login if not)
3. Checks existing consent
4. Redirects to consent page if needed
5. Issues authorization code and redirects back to client
Currently returns a 501 Not Implemented response.
Required parameters:
- response_type: Must be "code"
- client_id: Registered client ID
- redirect_uri: Must match registered URI
Recommended parameters:
- state: CSRF protection
- code_challenge + code_challenge_method: PKCE (required for public clients)
- scope: Requested permissions
""",
operation_id="oauth_provider_authorize",
tags=["OAuth Provider"],
)
@limiter.limit("30/minute")
async def authorize(
request: Request,
response_type: str = Query(..., description="Must be 'code'"),
client_id: str = Query(..., description="OAuth client ID"),
redirect_uri: str = Query(..., description="Redirect URI"),
scope: str = Query(default="", description="Requested scopes"),
scope: str = Query(default="", description="Requested scopes (space-separated)"),
state: str = Query(default="", description="CSRF state parameter"),
code_challenge: str | None = Query(default=None, description="PKCE code challenge"),
code_challenge_method: str | None = Query(
default=None, description="PKCE method (S256)"
),
nonce: str | None = Query(default=None, description="OpenID Connect nonce"),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User | None = Depends(get_current_active_user),
) -> Any:
"""
Authorization endpoint (skeleton).
Authorization endpoint - initiates OAuth flow.
In a full implementation, this would:
1. Validate the client and redirect URI
2. Authenticate the user (if not already)
3. Show consent screen
4. Generate authorization code
5. Redirect to redirect_uri with code
If user is not authenticated, redirects to login with return URL.
If user has not consented, redirects to consent page.
If all checks pass, generates code and redirects to client.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
)
# Validate client exists
client = await oauth_client.get_by_client_id(db, client_id=client_id)
if not client:
# Validate response_type
if response_type != "code":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="invalid_client: Unknown client_id",
detail="invalid_request: response_type must be 'code'",
)
# Validate redirect_uri
if redirect_uri not in (client.redirect_uris or []):
# Validate PKCE method if provided
if code_challenge_method and code_challenge_method not in ["S256", "plain"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="invalid_request: Invalid redirect_uri",
detail="invalid_request: code_challenge_method must be 'S256'",
)
# Skeleton: Return not implemented
# Full implementation would redirect to consent screen
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Authorization endpoint not fully implemented. "
"This is a skeleton for MCP integration.",
# Validate client
try:
client = await provider_service.get_client(db, client_id)
if not client:
raise provider_service.InvalidClientError("Unknown client_id")
provider_service.validate_redirect_uri(client, redirect_uri)
except provider_service.OAuthProviderError as e:
# For client/redirect errors, we can't safely redirect - show error
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{e.error}: {e.error_description}",
)
# Validate and filter scopes
try:
requested_scopes = provider_service.parse_scope(scope)
valid_scopes = provider_service.validate_scopes(client, requested_scopes)
except provider_service.InvalidScopeError as e:
# Redirect with error
error_params = {
"error": e.error,
"error_description": e.error_description,
}
if state:
error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Public clients MUST use PKCE
if client.client_type == "public":
if not code_challenge or code_challenge_method != "S256":
error_params = {
"error": "invalid_request",
"error_description": "PKCE with S256 is required for public clients",
}
if state:
error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(error_params)}",
status_code=status.HTTP_302_FOUND,
)
# If user is not authenticated, redirect to login
if not current_user:
# Store authorization request in session and redirect to login
# The frontend will handle the return URL
login_url = f"{settings.FRONTEND_URL}/login"
return_params = urlencode({
"oauth_authorize": "true",
"client_id": client_id,
"redirect_uri": redirect_uri,
"scope": " ".join(valid_scopes),
"state": state,
"code_challenge": code_challenge or "",
"code_challenge_method": code_challenge_method or "",
"nonce": nonce or "",
})
return RedirectResponse(
url=f"{login_url}?return_to=/auth/consent?{return_params}",
status_code=status.HTTP_302_FOUND,
)
# Check if user has already consented
has_consent = await provider_service.check_consent(
db, current_user.id, client_id, valid_scopes
)
if not has_consent:
# Redirect to consent page
consent_params = urlencode({
"client_id": client_id,
"client_name": client.client_name,
"redirect_uri": redirect_uri,
"scope": " ".join(valid_scopes),
"state": state,
"code_challenge": code_challenge or "",
"code_challenge_method": code_challenge_method or "",
"nonce": nonce or "",
})
return RedirectResponse(
url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}",
status_code=status.HTTP_302_FOUND,
)
# User is authenticated and has consented - issue authorization code
try:
code = await provider_service.create_authorization_code(
db=db,
client=client,
user=current_user,
redirect_uri=redirect_uri,
scope=" ".join(valid_scopes),
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
)
except provider_service.OAuthProviderError as e:
error_params = {
"error": e.error,
"error_description": e.error_description,
}
if state:
error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Success - redirect with code
success_params = {"code": code}
if state:
success_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(success_params)}",
status_code=status.HTTP_302_FOUND,
)
@router.post(
"/provider/authorize/consent",
summary="Submit Authorization Consent",
description="""
Submit user consent for OAuth authorization.
Called by the consent page after user approves or denies.
""",
operation_id="oauth_provider_consent",
tags=["OAuth Provider"],
)
@limiter.limit("30/minute")
async def submit_consent(
request: Request,
approved: bool = Form(..., description="Whether user approved"),
client_id: str = Form(..., description="OAuth client ID"),
redirect_uri: str = Form(..., description="Redirect URI"),
scope: str = Form(default="", description="Granted scopes"),
state: str = Form(default="", description="CSRF state parameter"),
code_challenge: str | None = Form(default=None),
code_challenge_method: str | None = Form(default=None),
nonce: str | None = Form(default=None),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_active_user),
) -> Any:
"""Process consent form submission."""
# Validate client
try:
client = await provider_service.get_client(db, client_id)
if not client:
raise provider_service.InvalidClientError("Unknown client_id")
provider_service.validate_redirect_uri(client, redirect_uri)
except provider_service.OAuthProviderError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{e.error}: {e.error_description}",
)
# If user denied, redirect with error
if not approved:
error_params = {
"error": "access_denied",
"error_description": "User denied authorization",
}
if state:
error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Parse and validate scopes
granted_scopes = provider_service.parse_scope(scope)
valid_scopes = provider_service.validate_scopes(client, granted_scopes)
# Record consent
await provider_service.grant_consent(
db, current_user.id, client_id, valid_scopes
)
# Generate authorization code
try:
code = await provider_service.create_authorization_code(
db=db,
client=client,
user=current_user,
redirect_uri=redirect_uri,
scope=" ".join(valid_scopes),
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
)
except provider_service.OAuthProviderError as e:
error_params = {
"error": e.error,
"error_description": e.error_description,
}
if state:
error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Success
success_params = {"code": code}
if state:
success_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(success_params)}",
status_code=status.HTTP_302_FOUND,
)
# ============================================================================
# Token Endpoint
# ============================================================================
@router.post(
"/provider/token",
summary="Token Endpoint (Skeleton)",
response_model=OAuthTokenResponse,
summary="Token Endpoint",
description="""
OAuth 2.0 Token Endpoint.
**NOTE**: This is a skeleton implementation. In a full implementation,
this would exchange authorization codes for access tokens.
Supports:
- authorization_code: Exchange code for tokens
- refresh_token: Refresh access token
Currently returns a 501 Not Implemented response.
Client authentication:
- Confidential clients: client_secret (Basic auth or POST body)
- Public clients: No secret, but PKCE code_verifier required
""",
operation_id="oauth_provider_token",
tags=["OAuth Provider"],
)
@limiter.limit("60/minute")
async def token(
grant_type: str = Form(..., description="Grant type (authorization_code)"),
request: Request,
grant_type: str = Form(..., description="Grant type"),
code: str | None = Form(default=None, description="Authorization code"),
redirect_uri: str | None = Form(default=None, description="Redirect URI"),
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
code_verifier: str | None = Form(default=None, description="PKCE code verifier"),
refresh_token: str | None = Form(default=None, description="Refresh token"),
scope: str | None = Form(default=None, description="Scope (for refresh)"),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Token endpoint (skeleton).
_: None = Depends(require_provider_enabled),
) -> OAuthTokenResponse:
"""Token endpoint - exchange code for tokens or refresh."""
# Extract client credentials from Basic auth if not in body
if not client_id:
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
import base64
try:
decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1)
except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body
pass
Supported grant types (when fully implemented):
- authorization_code: Exchange code for tokens
- refresh_token: Refresh access token
"""
if not settings.OAUTH_PROVIDER_ENABLED:
if not client_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_client: client_id required",
headers={"WWW-Authenticate": "Basic"},
)
if grant_type not in ["authorization_code", "refresh_token"]:
# Get device info
device_info = request.headers.get("User-Agent", "")[:500]
ip_address = get_remote_address(request)
try:
if grant_type == "authorization_code":
if not code:
raise provider_service.InvalidRequestError("code required")
if not redirect_uri:
raise provider_service.InvalidRequestError("redirect_uri required")
result = await provider_service.exchange_authorization_code(
db=db,
code=code,
client_id=client_id,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
client_secret=client_secret,
device_info=device_info,
ip_address=ip_address,
)
elif grant_type == "refresh_token":
if not refresh_token:
raise provider_service.InvalidRequestError("refresh_token required")
result = await provider_service.refresh_tokens(
db=db,
refresh_token=refresh_token,
client_id=client_id,
client_secret=client_secret,
scope=scope,
device_info=device_info,
ip_address=ip_address,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="unsupported_grant_type: Must be authorization_code or refresh_token",
)
return OAuthTokenResponse(**result)
except provider_service.InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"{e.error}: {e.error_description}",
headers={"WWW-Authenticate": "Basic"},
)
except provider_service.OAuthProviderError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="unsupported_grant_type",
detail=f"{e.error}: {e.error_description}",
)
# Skeleton: Return not implemented
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Token endpoint not fully implemented. "
"This is a skeleton for MCP integration.",
)
# ============================================================================
# Token Revocation (RFC 7009)
# ============================================================================
@router.post(
"/provider/revoke",
summary="Token Revocation Endpoint (Skeleton)",
status_code=status.HTTP_200_OK,
summary="Token Revocation Endpoint",
description="""
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
**NOTE**: This is a skeleton implementation.
Currently returns a 501 Not Implemented response.
Revokes an access token or refresh token.
Always returns 200 OK (even if token is invalid) per spec.
""",
operation_id="oauth_provider_revoke",
tags=["OAuth Provider"],
)
@limiter.limit("30/minute")
async def revoke(
request: Request,
token: str = Form(..., description="Token to revoke"),
token_type_hint: str | None = Form(
default=None, description="Token type hint (access_token, refresh_token)"
@@ -225,88 +536,286 @@ async def revoke(
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Token revocation endpoint (skeleton).
_: None = Depends(require_provider_enabled),
) -> dict[str, str]:
"""Revoke a token."""
# Extract client credentials from Basic auth if not in body
if not client_id:
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
import base64
try:
decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1)
except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body
pass
In a full implementation, this would invalidate the specified token.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
try:
await provider_service.revoke_token(
db=db,
token=token,
token_type_hint=token_type_hint,
client_id=client_id,
client_secret=client_secret,
)
except provider_service.InvalidClientError:
# Per RFC 7009, we should return 200 OK even for errors
# But client authentication errors can return 401
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_client",
headers={"WWW-Authenticate": "Basic"},
)
except Exception as e:
# Log but don't expose errors per RFC 7009
logger.warning(f"Token revocation error: {e}")
# Skeleton: Return not implemented
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Revocation endpoint not fully implemented. "
"This is a skeleton for MCP integration.",
)
# Always return 200 OK per RFC 7009
return {"status": "ok"}
# ============================================================================
# Client Management (Admin only)
# Token Introspection (RFC 7662)
# ============================================================================
@router.post(
"/provider/introspect",
response_model=OAuthTokenIntrospectionResponse,
summary="Token Introspection Endpoint",
description="""
OAuth 2.0 Token Introspection Endpoint (RFC 7662).
Allows resource servers to query the authorization server
to determine the active state and metadata of a token.
""",
operation_id="oauth_provider_introspect",
tags=["OAuth Provider"],
)
@limiter.limit("120/minute")
async def introspect(
request: Request,
token: str = Form(..., description="Token to introspect"),
token_type_hint: str | None = Form(
default=None, description="Token type hint (access_token, refresh_token)"
),
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
) -> OAuthTokenIntrospectionResponse:
"""Introspect a token."""
# Extract client credentials from Basic auth if not in body
if not client_id:
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
import base64
try:
decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1)
except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body
pass
try:
result = await provider_service.introspect_token(
db=db,
token=token,
token_type_hint=token_type_hint,
client_id=client_id,
client_secret=client_secret,
)
return OAuthTokenIntrospectionResponse(**result)
except provider_service.InvalidClientError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_client",
headers={"WWW-Authenticate": "Basic"},
)
except Exception as e:
logger.warning(f"Token introspection error: {e}")
return OAuthTokenIntrospectionResponse(active=False)
# ============================================================================
# Client Management (Admin)
# ============================================================================
@router.post(
"/provider/clients",
summary="Register OAuth Client (Admin)",
response_model=dict,
summary="Register OAuth Client",
description="""
Register a new OAuth client (admin only).
This endpoint allows creating MCP clients that can authenticate
against this API.
Creates an MCP client that can authenticate against this API.
Returns client_id and client_secret (for confidential clients).
**NOTE**: This is a minimal implementation.
**Important:** Store the client_secret securely - it won't be shown again!
""",
operation_id="register_oauth_client",
tags=["OAuth Provider"],
tags=["OAuth Provider Admin"],
)
async def register_client(
client_name: str = Form(..., description="Client application name"),
redirect_uris: str = Form(..., description="Comma-separated list of redirect URIs"),
redirect_uris: str = Form(..., description="Comma-separated redirect URIs"),
client_type: str = Form(default="public", description="public or confidential"),
scopes: str = Form(
default="openid profile email",
description="Allowed scopes (space-separated)",
),
mcp_server_url: str | None = Form(default=None, description="MCP server URL"),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Register a new OAuth client (skeleton).
In a full implementation, this would require admin authentication.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_superuser),
) -> dict:
"""Register a new OAuth client."""
# Parse redirect URIs
uris = [uri.strip() for uri in redirect_uris.split(",") if uri.strip()]
if not uris:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one redirect_uri is required",
)
# NOTE: In production, this should require admin authentication
# For now, this is a skeleton that shows the structure
from app.schemas.oauth import OAuthClientCreate
# Parse scopes
allowed_scopes = [s.strip() for s in scopes.split() if s.strip()]
client_data = OAuthClientCreate(
client_name=client_name,
client_description=None,
redirect_uris=[uri.strip() for uri in redirect_uris.split(",")],
allowed_scopes=["openid", "profile", "email"],
redirect_uris=uris,
allowed_scopes=allowed_scopes,
client_type=client_type,
)
client, secret = await oauth_client.create_client(db, obj_in=client_data)
client, secret = await oauth_client_crud.create_client(db, obj_in=client_data)
# Update MCP server URL if provided
if mcp_server_url:
client.mcp_server_url = mcp_server_url
await db.commit()
result = {
"client_id": client.client_id,
"client_name": client.client_name,
"client_type": client.client_type,
"redirect_uris": client.redirect_uris,
"allowed_scopes": client.allowed_scopes,
}
if secret:
result["client_secret"] = secret
result["warning"] = (
"Store the client_secret securely. It will not be shown again."
"Store the client_secret securely! It will not be shown again."
)
return result
@router.get(
"/provider/clients",
response_model=list[OAuthClientResponse],
summary="List OAuth Clients",
description="List all registered OAuth clients (admin only).",
operation_id="list_oauth_clients",
tags=["OAuth Provider Admin"],
)
async def list_clients(
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_superuser),
) -> list[OAuthClientResponse]:
"""List all OAuth clients."""
clients = await oauth_client_crud.get_all_clients(db)
return [OAuthClientResponse.model_validate(c) for c in clients]
@router.delete(
"/provider/clients/{client_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete OAuth Client",
description="Delete an OAuth client (admin only). Revokes all tokens.",
operation_id="delete_oauth_client",
tags=["OAuth Provider Admin"],
)
async def delete_client(
client_id: str,
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_superuser),
) -> None:
"""Delete an OAuth client."""
client = await provider_service.get_client(db, client_id)
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Client not found",
)
await oauth_client_crud.delete_client(db, client_id=client_id)
# ============================================================================
# User Consent Management
# ============================================================================
@router.get(
"/provider/consents",
summary="List My Consents",
description="List OAuth applications the current user has authorized.",
operation_id="list_my_oauth_consents",
tags=["OAuth Provider"],
)
async def list_my_consents(
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_active_user),
) -> list[dict]:
"""List applications the user has authorized."""
from sqlalchemy import select
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent
result = await db.execute(
select(OAuthConsent, OAuthClient)
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
.where(OAuthConsent.user_id == current_user.id)
)
rows = result.all()
return [
{
"client_id": consent.client_id,
"client_name": client.client_name,
"client_description": client.client_description,
"granted_scopes": consent.granted_scopes.split() if consent.granted_scopes else [],
"granted_at": consent.created_at.isoformat(),
}
for consent, client in rows
]
@router.delete(
"/provider/consents/{client_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Revoke My Consent",
description="Revoke authorization for an OAuth application. Also revokes all tokens.",
operation_id="revoke_my_oauth_consent",
tags=["OAuth Provider"],
)
async def revoke_my_consent(
client_id: str,
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_active_user),
) -> None:
"""Revoke consent for an application."""
revoked = await provider_service.revoke_consent(db, current_user.id, client_id)
if not revoked:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No consent found for this client",
)

View File

@@ -643,6 +643,62 @@ class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
logger.error(f"Error verifying client secret: {e!s}")
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""
Get all OAuth clients.
Args:
db: Database session
include_inactive: Whether to include inactive clients
Returns:
List of OAuthClient objects
"""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting all OAuth clients: {e!s}")
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""
Delete an OAuth client permanently.
Note: This will cascade delete related records (tokens, consents, etc.)
due to foreign key constraints.
Args:
db: Database session
client_id: OAuth client ID
Returns:
True if deleted, False if not found
"""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(f"OAuth client deleted: {client_id}")
else:
logger.warning(f"OAuth client not found for deletion: {client_id}")
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
raise
# ============================================================================
# Singleton instances

View File

@@ -8,9 +8,13 @@ from app.core.database import Base
from .base import TimestampMixin, UUIDMixin
# OAuth models
# OAuth models (client mode - authenticate via Google/GitHub)
from .oauth_account import OAuthAccount
# OAuth provider models (server mode - act as authorization server for MCP)
from .oauth_authorization_code import OAuthAuthorizationCode
from .oauth_client import OAuthClient
from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
from .oauth_state import OAuthState
from .organization import Organization
@@ -22,7 +26,10 @@ from .user_session import UserSession
__all__ = [
"Base",
"OAuthAccount",
"OAuthAuthorizationCode",
"OAuthClient",
"OAuthConsent",
"OAuthProviderRefreshToken",
"OAuthState",
"Organization",
"OrganizationRole",

View File

@@ -0,0 +1,91 @@
"""OAuth authorization code model for OAuth provider mode."""
from datetime import datetime
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
"""
OAuth 2.0 Authorization Code for the authorization code flow.
Authorization codes are:
- Single-use (marked as used after exchange)
- Short-lived (10 minutes default)
- Bound to specific client, user, redirect_uri
- Support PKCE (code_challenge/code_challenge_method)
Security considerations:
- Code must be cryptographically random (64 chars, URL-safe)
- Must validate redirect_uri matches exactly
- Must verify PKCE code_verifier for public clients
- Must be consumed within expiration time
"""
__tablename__ = "oauth_authorization_codes"
# The authorization code (cryptographically random, URL-safe)
code = Column(String(128), unique=True, nullable=False, index=True)
# Client that requested the code
client_id = Column(
String(64),
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
nullable=False,
)
# User who authorized the request
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Redirect URI (must match exactly on token exchange)
redirect_uri = Column(String(2048), nullable=False)
# Granted scopes (space-separated)
scope = Column(String(1000), nullable=False, default="")
# PKCE support (required for public clients)
code_challenge = Column(String(128), nullable=True)
code_challenge_method = Column(String(10), nullable=True) # "S256" or "plain"
# State parameter (for CSRF protection, returned to client)
state = Column(String(256), nullable=True)
# Nonce (for OpenID Connect, included in ID token)
nonce = Column(String(256), nullable=True)
# Expiration (codes are short-lived, typically 10 minutes)
expires_at = Column(DateTime(timezone=True), nullable=False)
# Single-use flag (set to True after successful exchange)
used = Column(Boolean, default=False, nullable=False)
# Relationships
client = relationship("OAuthClient", backref="authorization_codes")
user = relationship("User", backref="oauth_authorization_codes")
# Indexes for efficient cleanup queries
__table_args__ = (
Index("ix_oauth_authorization_codes_expires_at", "expires_at"),
Index("ix_oauth_authorization_codes_client_user", "client_id", "user_id"),
)
def __repr__(self):
return f"<OAuthAuthorizationCode {self.code[:8]}... for {self.client_id}>"
@property
def is_expired(self) -> bool:
"""Check if the authorization code has expired."""
return datetime.utcnow() > self.expires_at.replace(tzinfo=None)
@property
def is_valid(self) -> bool:
"""Check if the authorization code is valid (not used, not expired)."""
return not self.used and not self.is_expired

View File

@@ -0,0 +1,153 @@
"""OAuth provider token models for OAuth provider mode."""
from datetime import datetime
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
"""
OAuth 2.0 Refresh Token for the OAuth provider.
Refresh tokens are:
- Opaque (stored as hash in DB, actual token given to client)
- Long-lived (configurable, default 30 days)
- Revocable (via revoked flag or deletion)
- Bound to specific client, user, and scope
Access tokens are JWTs and not stored in DB (self-contained).
This model only tracks refresh tokens for revocation support.
Security considerations:
- Store token hash, not plaintext
- Support token rotation (new refresh token on use)
- Track last used time for security auditing
- Support revocation by user, client, or admin
"""
__tablename__ = "oauth_provider_refresh_tokens"
# Hash of the refresh token (SHA-256)
# We store hash, not plaintext, for security
token_hash = Column(String(64), unique=True, nullable=False, index=True)
# Unique token ID (JTI) - used in JWT access tokens to reference this refresh token
jti = Column(String(64), unique=True, nullable=False, index=True)
# Client that owns this token
client_id = Column(
String(64),
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
nullable=False,
)
# User who authorized this token
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Granted scopes (space-separated)
scope = Column(String(1000), nullable=False, default="")
# Token expiration
expires_at = Column(DateTime(timezone=True), nullable=False)
# Revocation flag
revoked = Column(Boolean, default=False, nullable=False, index=True)
# Last used timestamp (for security auditing)
last_used_at = Column(DateTime(timezone=True), nullable=True)
# Device/session info (optional, for user visibility)
device_info = Column(String(500), nullable=True)
ip_address = Column(String(45), nullable=True)
# Relationships
client = relationship("OAuthClient", backref="refresh_tokens")
user = relationship("User", backref="oauth_provider_refresh_tokens")
# Indexes
__table_args__ = (
Index("ix_oauth_provider_refresh_tokens_expires_at", "expires_at"),
Index("ix_oauth_provider_refresh_tokens_client_user", "client_id", "user_id"),
Index(
"ix_oauth_provider_refresh_tokens_user_revoked",
"user_id",
"revoked",
),
)
def __repr__(self):
status = "revoked" if self.revoked else "active"
return f"<OAuthProviderRefreshToken {self.jti[:8]}... ({status})>"
@property
def is_expired(self) -> bool:
"""Check if the refresh token has expired."""
return datetime.utcnow() > self.expires_at.replace(tzinfo=None)
@property
def is_valid(self) -> bool:
"""Check if the refresh token is valid (not revoked, not expired)."""
return not self.revoked and not self.is_expired
class OAuthConsent(Base, UUIDMixin, TimestampMixin):
"""
OAuth consent record - remembers user consent for a client.
When a user grants consent to an OAuth client, we store the record
so they don't have to re-consent on subsequent authorizations
(unless scopes change).
This enables a better UX - users only see consent screen once per client,
unless the client requests additional scopes.
"""
__tablename__ = "oauth_consents"
# User who granted consent
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Client that received consent
client_id = Column(
String(64),
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
nullable=False,
)
# Granted scopes (space-separated)
granted_scopes = Column(String(1000), nullable=False, default="")
# Relationships
client = relationship("OAuthClient", backref="consents")
user = relationship("User", backref="oauth_consents")
# Unique constraint: one consent record per user+client
__table_args__ = (
Index(
"ix_oauth_consents_user_client",
"user_id",
"client_id",
unique=True,
),
)
def __repr__(self):
return f"<OAuthConsent user={self.user_id} client={self.client_id}>"
def has_scopes(self, requested_scopes: list[str]) -> bool:
"""Check if all requested scopes are already granted."""
granted = set(self.granted_scopes.split()) if self.granted_scopes else set()
requested = set(requested_scopes)
return requested.issubset(granted)

View File

@@ -284,6 +284,9 @@ class OAuthServerMetadata(BaseModel):
revocation_endpoint: str | None = Field(
None, description="Token revocation endpoint"
)
introspection_endpoint: str | None = Field(
None, description="Token introspection endpoint (RFC 7662)"
)
scopes_supported: list[str] = Field(
default_factory=list, description="Supported scopes"
)
@@ -297,6 +300,10 @@ class OAuthServerMetadata(BaseModel):
code_challenge_methods_supported: list[str] = Field(
default_factory=lambda: ["S256"], description="Supported PKCE methods"
)
token_endpoint_auth_methods_supported: list[str] = Field(
default_factory=lambda: ["client_secret_basic", "client_secret_post", "none"],
description="Supported client authentication methods",
)
model_config = ConfigDict(
json_schema_extra={
@@ -304,10 +311,105 @@ class OAuthServerMetadata(BaseModel):
"issuer": "https://api.example.com",
"authorization_endpoint": "https://api.example.com/oauth/authorize",
"token_endpoint": "https://api.example.com/oauth/token",
"revocation_endpoint": "https://api.example.com/oauth/revoke",
"introspection_endpoint": "https://api.example.com/oauth/introspect",
"scopes_supported": ["openid", "profile", "email", "read:users"],
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"code_challenge_methods_supported": ["S256"],
"token_endpoint_auth_methods_supported": [
"client_secret_basic",
"client_secret_post",
"none",
],
}
}
)
# ============================================================================
# OAuth Token Responses (RFC 6749)
# ============================================================================
class OAuthTokenResponse(BaseModel):
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
access_token: str = Field(..., description="The access token issued by the server")
token_type: str = Field(
default="Bearer", description="The type of token (typically 'Bearer')"
)
expires_in: int | None = Field(
None, description="Token lifetime in seconds"
)
refresh_token: str | None = Field(
None, description="Refresh token for obtaining new access tokens"
)
scope: str | None = Field(
None, description="Space-separated list of granted scopes"
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "dGhpcyBpcyBhIHJlZnJlc2ggdG9rZW4...",
"scope": "openid profile email",
}
}
)
class OAuthTokenIntrospectionResponse(BaseModel):
"""OAuth 2.0 Token Introspection Response (RFC 7662)."""
active: bool = Field(
..., description="Whether the token is currently active"
)
scope: str | None = Field(
None, description="Space-separated list of scopes"
)
client_id: str | None = Field(
None, description="Client identifier for the token"
)
username: str | None = Field(
None, description="Human-readable identifier for the resource owner"
)
token_type: str | None = Field(
None, description="Type of the token (e.g., 'Bearer')"
)
exp: int | None = Field(
None, description="Token expiration time (Unix timestamp)"
)
iat: int | None = Field(
None, description="Token issue time (Unix timestamp)"
)
nbf: int | None = Field(
None, description="Token not-before time (Unix timestamp)"
)
sub: str | None = Field(
None, description="Subject of the token (user ID)"
)
aud: str | None = Field(
None, description="Intended audience of the token"
)
iss: str | None = Field(
None, description="Issuer of the token"
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"active": True,
"scope": "openid profile",
"client_id": "client123",
"username": "user@example.com",
"token_type": "Bearer",
"exp": 1735689600,
"iat": 1735686000,
"sub": "user-uuid-here",
}
}
)

File diff suppressed because it is too large Load Diff

View File

@@ -344,8 +344,8 @@ class TestOAuthProviderEndpoints:
assert response.status_code == 404
@pytest.mark.asyncio
async def test_provider_authorize_skeleton(self, client, async_test_db):
"""Test provider authorize returns not implemented (skeleton)."""
async def test_provider_authorize_requires_auth(self, client, async_test_db):
"""Test provider authorize requires authentication."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a test client
@@ -374,12 +374,12 @@ class TestOAuthProviderEndpoints:
"redirect_uri": "http://localhost:3000/callback",
},
)
# Should return 501 Not Implemented (skeleton)
assert response.status_code == 501
# Authorize endpoint requires authentication
assert response.status_code == 401
@pytest.mark.asyncio
async def test_provider_token_skeleton(self, client):
"""Test provider token returns not implemented (skeleton)."""
async def test_provider_token_requires_client_id(self, client):
"""Test provider token requires client_id."""
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
mock_settings.OAUTH_PROVIDER_ENABLED = True
@@ -390,5 +390,5 @@ class TestOAuthProviderEndpoints:
"code": "test_code",
},
)
# Should return 501 Not Implemented (skeleton)
assert response.status_code == 501
# Missing client_id returns 401 (invalid_client)
assert response.status_code == 401

View File

@@ -0,0 +1,726 @@
# tests/services/test_oauth_provider_service.py
"""
Tests for OAuth Provider Service (Authorization Server mode for MCP).
Covers:
- Authorization code creation and exchange
- Token generation, refresh, and revocation
- PKCE verification
- Token introspection (RFC 7662)
- Consent management
- Error handling
"""
import base64
import hashlib
import secrets
from unittest.mock import patch
from uuid import uuid4
import pytest
import pytest_asyncio
from app.models.oauth_client import OAuthClient
from app.models.user import User
from app.services import oauth_provider_service as service
from app.utils.test_utils import setup_async_test_db, teardown_async_test_db
@pytest_asyncio.fixture(scope="function")
async def db():
"""Fixture provides testing engine and session for each test."""
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
async with AsyncTestingSessionLocal() as session:
yield session
await teardown_async_test_db(test_engine)
@pytest_asyncio.fixture
async def test_user(db):
"""Create a test user."""
user = User(
id=uuid4(),
email="testuser@example.com",
password_hash="$2b$12$test",
first_name="Test",
last_name="User",
is_active=True,
is_superuser=False,
)
db.add(user)
await db.commit()
await db.refresh(user)
return user
@pytest_asyncio.fixture
async def public_client(db):
"""Create a test public OAuth client."""
client = OAuthClient(
id=uuid4(),
client_id="test_public_client",
client_name="Test Public Client",
client_type="public",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["openid", "profile", "email", "read:users"],
is_active=True,
)
db.add(client)
await db.commit()
await db.refresh(client)
return client
@pytest_asyncio.fixture
async def confidential_client(db):
"""Create a test confidential OAuth client."""
secret = "test_client_secret"
secret_hash = hashlib.sha256(secret.encode()).hexdigest()
client = OAuthClient(
id=uuid4(),
client_id="test_confidential_client",
client_name="Test Confidential Client",
client_type="confidential",
client_secret_hash=secret_hash,
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["openid", "profile", "email"],
is_active=True,
)
db.add(client)
await db.commit()
await db.refresh(client)
return client, secret
class TestHelperFunctions:
"""Tests for helper functions."""
def test_generate_code_length(self):
"""Test authorization code generation has proper length."""
code = service.generate_code()
assert len(code) > 64 # Base64 encoding of 64 bytes
def test_generate_code_unique(self):
"""Test authorization codes are unique."""
codes = [service.generate_code() for _ in range(100)]
assert len(set(codes)) == 100
def test_generate_token(self):
"""Test token generation."""
token = service.generate_token()
assert len(token) > 32
def test_generate_jti(self):
"""Test JTI generation."""
jti = service.generate_jti()
assert len(jti) > 20
def test_hash_token(self):
"""Test token hashing."""
token = "test_token"
hashed = service.hash_token(token)
assert len(hashed) == 64 # SHA-256 hex digest
def test_hash_token_deterministic(self):
"""Test same token produces same hash."""
token = "test_token"
hash1 = service.hash_token(token)
hash2 = service.hash_token(token)
assert hash1 == hash2
def test_parse_scope(self):
"""Test scope parsing."""
assert service.parse_scope("openid profile email") == [
"openid",
"profile",
"email",
]
assert service.parse_scope("") == []
assert service.parse_scope(" openid profile ") == ["openid", "profile"]
def test_join_scope(self):
"""Test scope joining."""
# Result is sorted and deduplicated
result = service.join_scope(["profile", "openid", "profile"])
assert "openid" in result
assert "profile" in result
class TestPKCEVerification:
"""Tests for PKCE verification."""
def test_verify_pkce_s256_valid(self):
"""Test PKCE verification with S256 method."""
# Generate code_verifier
code_verifier = secrets.token_urlsafe(64)
# Generate code_challenge using S256
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
assert service.verify_pkce(code_verifier, code_challenge, "S256") is True
def test_verify_pkce_s256_invalid(self):
"""Test PKCE verification fails with wrong verifier."""
code_verifier = secrets.token_urlsafe(64)
wrong_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
assert service.verify_pkce(wrong_verifier, code_challenge, "S256") is False
def test_verify_pkce_plain(self):
"""Test PKCE verification with plain method."""
code_verifier = "test_verifier"
assert service.verify_pkce(code_verifier, code_verifier, "plain") is True
assert service.verify_pkce(code_verifier, "wrong", "plain") is False
def test_verify_pkce_unknown_method(self):
"""Test PKCE verification with unknown method returns False."""
assert service.verify_pkce("verifier", "challenge", "unknown") is False
class TestClientValidation:
"""Tests for client validation."""
@pytest.mark.asyncio
async def test_get_client_success(self, db, public_client):
"""Test getting a valid client."""
client = await service.get_client(db, public_client.client_id)
assert client is not None
assert client.client_id == public_client.client_id
@pytest.mark.asyncio
async def test_get_client_not_found(self, db):
"""Test getting a non-existent client."""
client = await service.get_client(db, "nonexistent")
assert client is None
@pytest.mark.asyncio
async def test_get_client_inactive(self, db, public_client):
"""Test getting an inactive client returns None."""
public_client.is_active = False
await db.commit()
client = await service.get_client(db, public_client.client_id)
assert client is None
@pytest.mark.asyncio
async def test_validate_client_public(self, db, public_client):
"""Test validating a public client."""
client = await service.validate_client(db, public_client.client_id)
assert client.client_id == public_client.client_id
@pytest.mark.asyncio
async def test_validate_client_confidential_with_secret(
self, db, confidential_client
):
"""Test validating a confidential client with correct secret."""
client, secret = confidential_client
validated = await service.validate_client(db, client.client_id, secret)
assert validated.client_id == client.client_id
@pytest.mark.asyncio
async def test_validate_client_confidential_wrong_secret(
self, db, confidential_client
):
"""Test validating a confidential client with wrong secret."""
client, _ = confidential_client
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
await service.validate_client(db, client.client_id, "wrong_secret")
@pytest.mark.asyncio
async def test_validate_client_confidential_no_secret(self, db, confidential_client):
"""Test validating a confidential client without secret."""
client, _ = confidential_client
with pytest.raises(service.InvalidClientError, match="Client secret required"):
await service.validate_client(db, client.client_id)
def test_validate_redirect_uri_success(self, public_client):
"""Test validating a registered redirect URI."""
# Should not raise
service.validate_redirect_uri(public_client, "http://localhost:3000/callback")
def test_validate_redirect_uri_invalid(self, public_client):
"""Test validating an unregistered redirect URI."""
with pytest.raises(service.InvalidRequestError, match="Invalid redirect_uri"):
service.validate_redirect_uri(public_client, "http://evil.com/callback")
def test_validate_redirect_uri_no_uris(self, public_client):
"""Test validating when client has no URIs."""
public_client.redirect_uris = []
with pytest.raises(service.InvalidRequestError, match="no registered"):
service.validate_redirect_uri(public_client, "http://localhost:3000")
class TestScopeValidation:
"""Tests for scope validation."""
def test_validate_scopes_all_valid(self, public_client):
"""Test validating all valid scopes."""
scopes = service.validate_scopes(public_client, ["openid", "profile"])
assert "openid" in scopes
assert "profile" in scopes
def test_validate_scopes_partial_valid(self, public_client):
"""Test validating with some invalid scopes - filters them out."""
scopes = service.validate_scopes(public_client, ["openid", "invalid_scope"])
assert "openid" in scopes
assert "invalid_scope" not in scopes
def test_validate_scopes_empty_uses_all_allowed(self, public_client):
"""Test empty scope request uses all allowed scopes."""
scopes = service.validate_scopes(public_client, [])
assert set(scopes) == set(public_client.allowed_scopes)
def test_validate_scopes_none_valid(self, public_client):
"""Test validating with no valid scopes raises error."""
with pytest.raises(service.InvalidScopeError):
service.validate_scopes(public_client, ["invalid1", "invalid2"])
class TestAuthorizationCode:
"""Tests for authorization code creation and exchange."""
@pytest.mark.asyncio
async def test_create_authorization_code_public_with_pkce(
self, db, public_client, test_user
):
"""Test creating authorization code for public client with PKCE."""
code_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
code = await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid profile",
code_challenge=code_challenge,
code_challenge_method="S256",
)
assert code is not None
assert len(code) > 64
@pytest.mark.asyncio
async def test_create_authorization_code_public_without_pkce_fails(
self, db, public_client, test_user
):
"""Test creating authorization code for public client without PKCE fails."""
with pytest.raises(service.InvalidRequestError, match="PKCE"):
await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid",
)
@pytest.mark.asyncio
async def test_exchange_authorization_code_success(
self, db, public_client, test_user
):
"""Test exchanging valid authorization code for tokens."""
# Create PKCE challenge
code_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
# Create auth code
code = await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid profile",
code_challenge=code_challenge,
code_challenge_method="S256",
)
# Exchange code
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.exchange_authorization_code(
db=db,
code=code,
client_id=public_client.client_id,
redirect_uri="http://localhost:3000/callback",
code_verifier=code_verifier,
)
assert "access_token" in result
assert "refresh_token" in result
assert result["token_type"] == "Bearer"
assert "expires_in" in result
@pytest.mark.asyncio
async def test_exchange_authorization_code_invalid_code(self, db, public_client):
"""Test exchanging invalid code fails."""
with pytest.raises(service.InvalidGrantError, match="Invalid authorization"):
await service.exchange_authorization_code(
db=db,
code="invalid_code",
client_id=public_client.client_id,
redirect_uri="http://localhost:3000/callback",
)
@pytest.mark.asyncio
async def test_exchange_authorization_code_wrong_redirect_uri(
self, db, public_client, test_user
):
"""Test exchanging code with wrong redirect_uri fails."""
code_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
code = await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid",
code_challenge=code_challenge,
code_challenge_method="S256",
)
with pytest.raises(service.InvalidGrantError, match="redirect_uri mismatch"):
await service.exchange_authorization_code(
db=db,
code=code,
client_id=public_client.client_id,
redirect_uri="http://different.com/callback",
code_verifier=code_verifier,
)
@pytest.mark.asyncio
async def test_exchange_authorization_code_invalid_pkce(
self, db, public_client, test_user
):
"""Test exchanging code with invalid PKCE verifier fails."""
code_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
code = await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid",
code_challenge=code_challenge,
code_challenge_method="S256",
)
with pytest.raises(service.InvalidGrantError, match="Invalid code_verifier"):
await service.exchange_authorization_code(
db=db,
code=code,
client_id=public_client.client_id,
redirect_uri="http://localhost:3000/callback",
code_verifier="wrong_verifier",
)
class TestTokenRefresh:
"""Tests for token refresh."""
@pytest.mark.asyncio
async def test_refresh_tokens_success(self, db, public_client, test_user):
"""Test refreshing tokens successfully."""
# Create initial tokens
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid profile",
)
refresh_token = result["refresh_token"]
# Refresh the tokens
new_result = await service.refresh_tokens(
db=db,
refresh_token=refresh_token,
client_id=public_client.client_id,
)
assert "access_token" in new_result
assert "refresh_token" in new_result
assert new_result["refresh_token"] != refresh_token # Token rotation
@pytest.mark.asyncio
async def test_refresh_tokens_invalid_token(self, db, public_client):
"""Test refreshing with invalid token fails."""
with pytest.raises(service.InvalidGrantError, match="Invalid refresh token"):
await service.refresh_tokens(
db=db,
refresh_token="invalid_token",
client_id=public_client.client_id,
)
@pytest.mark.asyncio
async def test_refresh_tokens_scope_reduction(self, db, public_client, test_user):
"""Test refreshing with reduced scope."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid profile email",
)
new_result = await service.refresh_tokens(
db=db,
refresh_token=result["refresh_token"],
client_id=public_client.client_id,
scope="openid", # Reduced scope
)
assert "openid" in new_result["scope"]
assert "profile" not in new_result["scope"]
@pytest.mark.asyncio
async def test_refresh_tokens_scope_expansion_fails(
self, db, public_client, test_user
):
"""Test refreshing with expanded scope fails."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid",
)
with pytest.raises(service.InvalidScopeError, match="Cannot expand scope"):
await service.refresh_tokens(
db=db,
refresh_token=result["refresh_token"],
client_id=public_client.client_id,
scope="openid profile", # Expanded scope
)
class TestTokenRevocation:
"""Tests for token revocation."""
@pytest.mark.asyncio
async def test_revoke_refresh_token(self, db, public_client, test_user):
"""Test revoking a refresh token."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid",
)
# Revoke the token
revoked = await service.revoke_token(
db=db,
token=result["refresh_token"],
token_type_hint="refresh_token",
)
assert revoked is True
# Try to use revoked token
with pytest.raises(service.InvalidGrantError, match="revoked"):
await service.refresh_tokens(
db=db,
refresh_token=result["refresh_token"],
client_id=public_client.client_id,
)
@pytest.mark.asyncio
async def test_revoke_all_user_tokens(self, db, public_client, test_user):
"""Test revoking all tokens for a user."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
# Create multiple tokens (we don't need to capture results)
await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid",
)
await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="profile",
)
# Revoke all
count = await service.revoke_all_user_tokens(db, test_user.id)
assert count == 2
class TestTokenIntrospection:
"""Tests for token introspection (RFC 7662)."""
@pytest.mark.asyncio
async def test_introspect_valid_access_token(self, db, public_client, test_user):
"""Test introspecting a valid access token."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid profile",
)
introspection = await service.introspect_token(
db=db,
token=result["access_token"],
)
assert introspection["active"] is True
assert introspection["client_id"] == public_client.client_id
assert introspection["sub"] == str(test_user.id)
@pytest.mark.asyncio
async def test_introspect_invalid_token(self, db):
"""Test introspecting an invalid token."""
introspection = await service.introspect_token(
db=db,
token="invalid_token",
)
assert introspection["active"] is False
class TestConsentManagement:
"""Tests for consent management."""
@pytest.mark.asyncio
async def test_grant_consent(self, db, public_client, test_user):
"""Test granting consent."""
consent = await service.grant_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
scopes=["openid", "profile"],
)
assert consent is not None
assert "openid" in consent.granted_scopes
assert "profile" in consent.granted_scopes
@pytest.mark.asyncio
async def test_check_consent_granted(self, db, public_client, test_user):
"""Test checking granted consent."""
await service.grant_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
scopes=["openid", "profile"],
)
has_consent = await service.check_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
requested_scopes=["openid"],
)
assert has_consent is True
@pytest.mark.asyncio
async def test_check_consent_not_granted(self, db, public_client, test_user):
"""Test checking consent that hasn't been granted."""
has_consent = await service.check_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
requested_scopes=["openid"],
)
assert has_consent is False
@pytest.mark.asyncio
async def test_revoke_consent(self, db, public_client, test_user):
"""Test revoking consent."""
await service.grant_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
scopes=["openid"],
)
revoked = await service.revoke_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
)
assert revoked is True
# Check consent is gone
has_consent = await service.check_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
requested_scopes=["openid"],
)
assert has_consent is False
class TestOAuthErrors:
"""Tests for OAuth error classes."""
def test_invalid_client_error(self):
"""Test InvalidClientError."""
error = service.InvalidClientError("Test description")
assert error.error == "invalid_client"
assert error.error_description == "Test description"
def test_invalid_grant_error(self):
"""Test InvalidGrantError."""
error = service.InvalidGrantError("Test description")
assert error.error == "invalid_grant"
assert error.error_description == "Test description"
def test_invalid_request_error(self):
"""Test InvalidRequestError."""
error = service.InvalidRequestError("Test description")
assert error.error == "invalid_request"
assert error.error_description == "Test description"
def test_invalid_scope_error(self):
"""Test InvalidScopeError."""
error = service.InvalidScopeError("Test description")
assert error.error == "invalid_scope"
assert error.error_description == "Test description"
def test_access_denied_error(self):
"""Test AccessDeniedError."""
error = service.AccessDeniedError("Test description")
assert error.error == "access_denied"
assert error.error_description == "Test description"

View File

@@ -0,0 +1,325 @@
/**
* OAuth Consent Page
* Displays authorization consent form for OAuth provider mode (MCP integration).
*
* Users are redirected here when an external application (MCP client) requests
* access to their account. They can approve or deny the requested permissions.
*/
'use client';
import { useState, useEffect } from 'react';
import { useSearchParams } from 'next/navigation';
import { useRouter } from '@/lib/i18n/routing';
import { useTranslations } from 'next-intl';
import { Button } from '@/components/ui/button';
import {
Card,
CardContent,
CardDescription,
CardFooter,
CardHeader,
CardTitle,
} from '@/components/ui/card';
import { Alert, AlertDescription } from '@/components/ui/alert';
import { Checkbox } from '@/components/ui/checkbox';
import { Label } from '@/components/ui/label';
import { Loader2, Shield, AlertTriangle, ExternalLink, CheckCircle2 } from 'lucide-react';
import { useAuth } from '@/lib/auth/AuthContext';
import config from '@/config/app.config';
// Scope descriptions for display
const SCOPE_INFO: Record<string, { name: string; description: string; icon: string }> = {
openid: {
name: 'OpenID Connect',
description: 'Verify your identity',
icon: 'user',
},
profile: {
name: 'Profile',
description: 'Access your name and basic profile information',
icon: 'user-circle',
},
email: {
name: 'Email',
description: 'Access your email address',
icon: 'mail',
},
'read:users': {
name: 'Read Users',
description: 'View user information',
icon: 'users',
},
'write:users': {
name: 'Write Users',
description: 'Modify user information',
icon: 'user-edit',
},
'read:organizations': {
name: 'Read Organizations',
description: 'View organization information',
icon: 'building',
},
'write:organizations': {
name: 'Write Organizations',
description: 'Modify organization information',
icon: 'building-edit',
},
admin: {
name: 'Admin Access',
description: 'Full administrative access',
icon: 'shield',
},
};
interface ConsentParams {
clientId: string;
clientName: string;
redirectUri: string;
scope: string;
state: string;
codeChallenge: string;
codeChallengeMethod: string;
nonce: string;
}
export default function OAuthConsentPage() {
const searchParams = useSearchParams();
const router = useRouter();
// Note: t is available for future i18n use
const _t = useTranslations('auth.oauth');
void _t; // Suppress unused warning - ready for i18n
const { isAuthenticated, isLoading: authLoading } = useAuth();
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const [selectedScopes, setSelectedScopes] = useState<Set<string>>(new Set());
const [params, setParams] = useState<ConsentParams | null>(null);
// Parse URL parameters
useEffect(() => {
const clientId = searchParams.get('client_id') || '';
const clientName = searchParams.get('client_name') || 'Application';
const redirectUri = searchParams.get('redirect_uri') || '';
const scope = searchParams.get('scope') || '';
const state = searchParams.get('state') || '';
const codeChallenge = searchParams.get('code_challenge') || '';
const codeChallengeMethod = searchParams.get('code_challenge_method') || '';
const nonce = searchParams.get('nonce') || '';
if (!clientId || !redirectUri) {
setError('Invalid authorization request. Missing required parameters.');
return;
}
setParams({
clientId,
clientName,
redirectUri,
scope,
state,
codeChallenge,
codeChallengeMethod,
nonce,
});
// Initialize selected scopes with all requested scopes
if (scope) {
setSelectedScopes(new Set(scope.split(' ')));
}
}, [searchParams]);
// Redirect to login if not authenticated
useEffect(() => {
if (!authLoading && !isAuthenticated) {
const returnUrl = `/auth/consent?${searchParams.toString()}`;
router.push(`${config.routes.login}?return_to=${encodeURIComponent(returnUrl)}`);
}
}, [authLoading, isAuthenticated, router, searchParams]);
const handleScopeToggle = (scope: string) => {
setSelectedScopes((prev) => {
const next = new Set(prev);
if (next.has(scope)) {
next.delete(scope);
} else {
next.add(scope);
}
return next;
});
};
const handleSubmit = async (approved: boolean) => {
if (!params) return;
setIsSubmitting(true);
setError(null);
try {
// Create form data for consent submission
const formData = new FormData();
formData.append('approved', approved.toString());
formData.append('client_id', params.clientId);
formData.append('redirect_uri', params.redirectUri);
formData.append('scope', Array.from(selectedScopes).join(' '));
formData.append('state', params.state);
if (params.codeChallenge) {
formData.append('code_challenge', params.codeChallenge);
}
if (params.codeChallengeMethod) {
formData.append('code_challenge_method', params.codeChallengeMethod);
}
if (params.nonce) {
formData.append('nonce', params.nonce);
}
// Submit consent to backend
const apiUrl = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000';
const response = await fetch(`${apiUrl}/api/v1/oauth/provider/authorize/consent`, {
method: 'POST',
body: formData,
credentials: 'include',
});
// The endpoint returns a redirect, so follow it
if (response.redirected) {
window.location.href = response.url;
} else if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to process consent');
}
} catch (err) {
setError(err instanceof Error ? err.message : 'An unexpected error occurred');
setIsSubmitting(false);
}
};
// Show loading state while checking auth
if (authLoading) {
return (
<div className="flex min-h-screen items-center justify-center p-4">
<div className="text-center space-y-4">
<Loader2 className="h-8 w-8 animate-spin mx-auto text-primary" />
<p className="text-muted-foreground">Loading...</p>
</div>
</div>
);
}
// Show error state
if (error && !params) {
return (
<div className="flex min-h-screen items-center justify-center p-4">
<div className="w-full max-w-md space-y-4">
<Alert variant="destructive">
<AlertTriangle className="h-4 w-4" />
<AlertDescription>{error}</AlertDescription>
</Alert>
<div className="flex gap-2 justify-center">
<Button variant="outline" onClick={() => router.push(config.routes.login)}>
Back to Login
</Button>
</div>
</div>
</div>
);
}
if (!params) {
return null;
}
const requestedScopes = params.scope ? params.scope.split(' ') : [];
return (
<div className="flex min-h-screen items-center justify-center p-4">
<Card className="w-full max-w-md">
<CardHeader className="text-center">
<div className="flex justify-center mb-4">
<Shield className="h-12 w-12 text-primary" />
</div>
<CardTitle className="text-xl">Authorization Request</CardTitle>
<CardDescription className="mt-2">
<span className="font-semibold text-foreground">{params.clientName}</span> wants to
access your account
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
{error && (
<Alert variant="destructive">
<AlertTriangle className="h-4 w-4" />
<AlertDescription>{error}</AlertDescription>
</Alert>
)}
<div className="space-y-3">
<p className="text-sm font-medium">This application will be able to:</p>
<div className="space-y-2 border rounded-lg p-3">
{requestedScopes.map((scope) => {
const info = SCOPE_INFO[scope] || {
name: scope,
description: `Access to ${scope}`,
};
const isSelected = selectedScopes.has(scope);
return (
<div
key={scope}
className="flex items-start space-x-3 py-2 border-b last:border-0"
>
<Checkbox
id={`scope-${scope}`}
checked={isSelected}
onCheckedChange={() => handleScopeToggle(scope)}
disabled={isSubmitting}
/>
<div className="flex-1 space-y-0.5">
<Label
htmlFor={`scope-${scope}`}
className="text-sm font-medium cursor-pointer"
>
{info.name}
</Label>
<p className="text-xs text-muted-foreground">{info.description}</p>
</div>
{isSelected && <CheckCircle2 className="h-4 w-4 text-green-500 mt-0.5" />}
</div>
);
})}
</div>
</div>
<Alert>
<ExternalLink className="h-4 w-4" />
<AlertDescription className="text-xs">
After authorization, you will be redirected to:
<br />
<code className="text-xs break-all bg-muted px-1 py-0.5 rounded">
{params.redirectUri}
</code>
</AlertDescription>
</Alert>
</CardContent>
<CardFooter className="flex gap-3">
<Button
variant="outline"
className="flex-1"
onClick={() => handleSubmit(false)}
disabled={isSubmitting}
>
{isSubmitting ? <Loader2 className="h-4 w-4 animate-spin" /> : 'Deny'}
</Button>
<Button
className="flex-1"
onClick={() => handleSubmit(true)}
disabled={isSubmitting || selectedScopes.size === 0}
>
{isSubmitting ? <Loader2 className="h-4 w-4 animate-spin" /> : 'Authorize'}
</Button>
</CardFooter>
</Card>
</div>
);
}