diff --git a/AGENTS.md b/AGENTS.md index 3542f93..8c8a100 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -69,6 +69,27 @@ Default superuser (change in production): - `get_optional_current_user`: Accepts authenticated or anonymous - `get_current_superuser`: Requires superuser flag +### OAuth Provider Mode (MCP Integration) +Full OAuth 2.0 Authorization Server for MCP (Model Context Protocol) clients: +- **Authorization Code Flow with PKCE**: RFC 7636 compliant +- **JWT access tokens**: Self-contained, no DB lookup required +- **Opaque refresh tokens**: Stored hashed in database, supports rotation +- **Token introspection**: RFC 7662 compliant endpoint +- **Token revocation**: RFC 7009 compliant endpoint +- **Server metadata**: RFC 8414 compliant discovery endpoint +- **Consent management**: User can review and revoke app permissions + +**API endpoints:** +- `GET /.well-known/oauth-authorization-server` - Server metadata +- `GET /oauth/provider/authorize` - Authorization endpoint +- `POST /oauth/provider/authorize/consent` - Consent submission +- `POST /oauth/provider/token` - Token endpoint +- `POST /oauth/provider/revoke` - Token revocation +- `POST /oauth/provider/introspect` - Token introspection +- Client management endpoints (admin only) + +**Scopes supported:** `openid`, `profile`, `email`, `read:users`, `write:users`, `admin` + ### Database Pattern - **Async SQLAlchemy 2.0** with PostgreSQL - **Connection pooling**: 20 base connections, 50 max overflow @@ -238,6 +259,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a ### Completed Features ✅ - Authentication system (JWT with refresh tokens, OAuth/social login) +- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server - Session management (device tracking, revocation) - User management (CRUD, password change) - Organization system (multi-tenant with RBAC) diff --git a/backend/app/alembic/versions/f8c3d2e1a4b5_add_oauth_provider_models.py b/backend/app/alembic/versions/f8c3d2e1a4b5_add_oauth_provider_models.py new file mode 100644 index 0000000..214ed23 --- /dev/null +++ b/backend/app/alembic/versions/f8c3d2e1a4b5_add_oauth_provider_models.py @@ -0,0 +1,194 @@ +"""Add OAuth provider models for MCP integration. + +Revision ID: f8c3d2e1a4b5 +Revises: d5a7b2c9e1f3 +Create Date: 2025-01-15 10:00:00.000000 + +This migration adds tables for OAuth provider mode: +- oauth_authorization_codes: Temporary authorization codes +- oauth_provider_refresh_tokens: Long-lived refresh tokens +- oauth_consents: User consent records +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "f8c3d2e1a4b5" +down_revision = "d5a7b2c9e1f3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Create oauth_authorization_codes table + op.create_table( + "oauth_authorization_codes", + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("code", sa.String(128), nullable=False), + sa.Column("client_id", sa.String(64), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("redirect_uri", sa.String(2048), nullable=False), + sa.Column("scope", sa.String(1000), nullable=False, server_default=""), + sa.Column("code_challenge", sa.String(128), nullable=True), + sa.Column("code_challenge_method", sa.String(10), nullable=True), + sa.Column("state", sa.String(256), nullable=True), + sa.Column("nonce", sa.String(256), nullable=True), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("used", sa.Boolean(), nullable=False, server_default="false"), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["client_id"], + ["oauth_clients.client_id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_oauth_authorization_codes_code", + "oauth_authorization_codes", + ["code"], + unique=True, + ) + op.create_index( + "ix_oauth_authorization_codes_expires_at", + "oauth_authorization_codes", + ["expires_at"], + ) + op.create_index( + "ix_oauth_authorization_codes_client_user", + "oauth_authorization_codes", + ["client_id", "user_id"], + ) + + # Create oauth_provider_refresh_tokens table + op.create_table( + "oauth_provider_refresh_tokens", + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("token_hash", sa.String(64), nullable=False), + sa.Column("jti", sa.String(64), nullable=False), + sa.Column("client_id", sa.String(64), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("scope", sa.String(1000), nullable=False, server_default=""), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("revoked", sa.Boolean(), nullable=False, server_default="false"), + sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("device_info", sa.String(500), nullable=True), + sa.Column("ip_address", sa.String(45), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["client_id"], + ["oauth_clients.client_id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_oauth_provider_refresh_tokens_token_hash", + "oauth_provider_refresh_tokens", + ["token_hash"], + unique=True, + ) + op.create_index( + "ix_oauth_provider_refresh_tokens_jti", + "oauth_provider_refresh_tokens", + ["jti"], + unique=True, + ) + op.create_index( + "ix_oauth_provider_refresh_tokens_expires_at", + "oauth_provider_refresh_tokens", + ["expires_at"], + ) + op.create_index( + "ix_oauth_provider_refresh_tokens_client_user", + "oauth_provider_refresh_tokens", + ["client_id", "user_id"], + ) + op.create_index( + "ix_oauth_provider_refresh_tokens_user_revoked", + "oauth_provider_refresh_tokens", + ["user_id", "revoked"], + ) + op.create_index( + "ix_oauth_provider_refresh_tokens_revoked", + "oauth_provider_refresh_tokens", + ["revoked"], + ) + + # Create oauth_consents table + op.create_table( + "oauth_consents", + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("client_id", sa.String(64), nullable=False), + sa.Column("granted_scopes", sa.String(1000), nullable=False, server_default=""), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["client_id"], + ["oauth_clients.client_id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_oauth_consents_user_client", + "oauth_consents", + ["user_id", "client_id"], + unique=True, + ) + + +def downgrade() -> None: + op.drop_table("oauth_consents") + op.drop_table("oauth_provider_refresh_tokens") + op.drop_table("oauth_authorization_codes") diff --git a/backend/app/api/routes/oauth_provider.py b/backend/app/api/routes/oauth_provider.py index a52185c..b8d2b26 100644 --- a/backend/app/api/routes/oauth_provider.py +++ b/backend/app/api/routes/oauth_provider.py @@ -1,37 +1,63 @@ # app/api/routes/oauth_provider.py """ -OAuth Provider routes (Authorization Server mode). +OAuth Provider routes (Authorization Server mode) for MCP integration. -This is a skeleton implementation for MCP (Model Context Protocol) client authentication. -Provides basic OAuth 2.0 endpoints that can be expanded for full functionality. - -Endpoints: +Implements OAuth 2.0 Authorization Server endpoints: - GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414) -- GET /oauth/provider/authorize - Authorization endpoint (skeleton) -- POST /oauth/provider/token - Token endpoint (skeleton) -- POST /oauth/provider/revoke - Token revocation endpoint (skeleton) +- GET /oauth/provider/authorize - Authorization endpoint +- POST /oauth/provider/token - Token endpoint +- POST /oauth/provider/revoke - Token revocation (RFC 7009) +- POST /oauth/provider/introspect - Token introspection (RFC 7662) +- Client management endpoints -NOTE: This is intentionally minimal. Full implementation should include: -- Complete authorization code flow -- Refresh token handling -- Scope validation -- Client authentication -- PKCE support +Security features: +- PKCE required for public clients (S256) +- CSRF protection via state parameter +- Secure token handling +- Rate limiting on sensitive endpoints """ import logging from typing import Any +from urllib.parse import urlencode -from fastapi import APIRouter, Depends, Form, HTTPException, Query, status +from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, status +from fastapi.responses import RedirectResponse +from slowapi import Limiter +from slowapi.util import get_remote_address from sqlalchemy.ext.asyncio import AsyncSession +from app.api.dependencies.auth import get_current_active_user, get_current_superuser from app.core.config import settings from app.core.database import get_db -from app.crud import oauth_client -from app.schemas.oauth import OAuthServerMetadata +from app.crud import oauth_client as oauth_client_crud +from app.models.user import User +from app.schemas.oauth import ( + OAuthClientCreate, + OAuthClientResponse, + OAuthServerMetadata, + OAuthTokenIntrospectionResponse, + OAuthTokenResponse, +) +from app.services import oauth_provider_service as provider_service router = APIRouter() logger = logging.getLogger(__name__) +limiter = Limiter(key_func=get_remote_address) + + +def require_provider_enabled(): + """Dependency to check if OAuth provider mode is enabled.""" + if not settings.OAUTH_PROVIDER_ENABLED: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="OAuth provider mode is not enabled. Set OAUTH_PROVIDER_ENABLED=true", + ) + + +# ============================================================================ +# Server Metadata (RFC 8414) +# ============================================================================ @router.get( @@ -42,24 +68,15 @@ logger = logging.getLogger(__name__) OAuth 2.0 Authorization Server Metadata (RFC 8414). Returns server metadata including supported endpoints, scopes, - and capabilities for MCP clients. + and capabilities. MCP clients use this to discover the server. """, operation_id="get_oauth_server_metadata", tags=["OAuth Provider"], ) -async def get_server_metadata() -> Any: - """ - Get OAuth 2.0 server metadata. - - This endpoint is used by MCP clients to discover the authorization - server's capabilities. - """ - if not settings.OAUTH_PROVIDER_ENABLED: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="OAuth provider mode is not enabled", - ) - +async def get_server_metadata( + _: None = Depends(require_provider_enabled), +) -> OAuthServerMetadata: + """Get OAuth 2.0 server metadata.""" base_url = settings.OAUTH_ISSUER.rstrip("/") return OAuthServerMetadata( @@ -67,7 +84,8 @@ async def get_server_metadata() -> Any: authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize", token_endpoint=f"{base_url}/api/v1/oauth/provider/token", revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke", - registration_endpoint=None, # Dynamic registration not implemented + introspection_endpoint=f"{base_url}/api/v1/oauth/provider/introspect", + registration_endpoint=None, # Dynamic registration not supported scopes_supported=[ "openid", "profile", @@ -76,148 +94,441 @@ async def get_server_metadata() -> Any: "write:users", "read:organizations", "write:organizations", + "admin", ], response_types_supported=["code"], grant_types_supported=["authorization_code", "refresh_token"], code_challenge_methods_supported=["S256"], + token_endpoint_auth_methods_supported=[ + "client_secret_basic", + "client_secret_post", + "none", # For public clients with PKCE + ], ) +# ============================================================================ +# Authorization Endpoint +# ============================================================================ + + @router.get( "/provider/authorize", - summary="Authorization Endpoint (Skeleton)", + summary="Authorization Endpoint", description=""" OAuth 2.0 Authorization Endpoint. - **NOTE**: This is a skeleton implementation. In a full implementation, - this would: - 1. Validate client_id and redirect_uri - 2. Display consent screen to user - 3. Generate authorization code - 4. Redirect back to client with code + Initiates the authorization code flow: + 1. Validates client and parameters + 2. Checks if user is authenticated (redirects to login if not) + 3. Checks existing consent + 4. Redirects to consent page if needed + 5. Issues authorization code and redirects back to client - Currently returns a 501 Not Implemented response. + Required parameters: + - response_type: Must be "code" + - client_id: Registered client ID + - redirect_uri: Must match registered URI + + Recommended parameters: + - state: CSRF protection + - code_challenge + code_challenge_method: PKCE (required for public clients) + - scope: Requested permissions """, operation_id="oauth_provider_authorize", tags=["OAuth Provider"], ) +@limiter.limit("30/minute") async def authorize( + request: Request, response_type: str = Query(..., description="Must be 'code'"), client_id: str = Query(..., description="OAuth client ID"), redirect_uri: str = Query(..., description="Redirect URI"), - scope: str = Query(default="", description="Requested scopes"), + scope: str = Query(default="", description="Requested scopes (space-separated)"), state: str = Query(default="", description="CSRF state parameter"), code_challenge: str | None = Query(default=None, description="PKCE code challenge"), code_challenge_method: str | None = Query( default=None, description="PKCE method (S256)" ), + nonce: str | None = Query(default=None, description="OpenID Connect nonce"), db: AsyncSession = Depends(get_db), + _: None = Depends(require_provider_enabled), + current_user: User | None = Depends(get_current_active_user), ) -> Any: """ - Authorization endpoint (skeleton). + Authorization endpoint - initiates OAuth flow. - In a full implementation, this would: - 1. Validate the client and redirect URI - 2. Authenticate the user (if not already) - 3. Show consent screen - 4. Generate authorization code - 5. Redirect to redirect_uri with code + If user is not authenticated, redirects to login with return URL. + If user has not consented, redirects to consent page. + If all checks pass, generates code and redirects to client. """ - if not settings.OAUTH_PROVIDER_ENABLED: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="OAuth provider mode is not enabled", - ) - - # Validate client exists - client = await oauth_client.get_by_client_id(db, client_id=client_id) - if not client: + # Validate response_type + if response_type != "code": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="invalid_client: Unknown client_id", + detail="invalid_request: response_type must be 'code'", ) - # Validate redirect_uri - if redirect_uri not in (client.redirect_uris or []): + # Validate PKCE method if provided + if code_challenge_method and code_challenge_method not in ["S256", "plain"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="invalid_request: Invalid redirect_uri", + detail="invalid_request: code_challenge_method must be 'S256'", ) - # Skeleton: Return not implemented - # Full implementation would redirect to consent screen - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="Authorization endpoint not fully implemented. " - "This is a skeleton for MCP integration.", + # Validate client + try: + client = await provider_service.get_client(db, client_id) + if not client: + raise provider_service.InvalidClientError("Unknown client_id") + provider_service.validate_redirect_uri(client, redirect_uri) + except provider_service.OAuthProviderError as e: + # For client/redirect errors, we can't safely redirect - show error + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"{e.error}: {e.error_description}", + ) + + # Validate and filter scopes + try: + requested_scopes = provider_service.parse_scope(scope) + valid_scopes = provider_service.validate_scopes(client, requested_scopes) + except provider_service.InvalidScopeError as e: + # Redirect with error + error_params = { + "error": e.error, + "error_description": e.error_description, + } + if state: + error_params["state"] = state + return RedirectResponse( + url=f"{redirect_uri}?{urlencode(error_params)}", + status_code=status.HTTP_302_FOUND, + ) + + # Public clients MUST use PKCE + if client.client_type == "public": + if not code_challenge or code_challenge_method != "S256": + error_params = { + "error": "invalid_request", + "error_description": "PKCE with S256 is required for public clients", + } + if state: + error_params["state"] = state + return RedirectResponse( + url=f"{redirect_uri}?{urlencode(error_params)}", + status_code=status.HTTP_302_FOUND, + ) + + # If user is not authenticated, redirect to login + if not current_user: + # Store authorization request in session and redirect to login + # The frontend will handle the return URL + login_url = f"{settings.FRONTEND_URL}/login" + return_params = urlencode({ + "oauth_authorize": "true", + "client_id": client_id, + "redirect_uri": redirect_uri, + "scope": " ".join(valid_scopes), + "state": state, + "code_challenge": code_challenge or "", + "code_challenge_method": code_challenge_method or "", + "nonce": nonce or "", + }) + return RedirectResponse( + url=f"{login_url}?return_to=/auth/consent?{return_params}", + status_code=status.HTTP_302_FOUND, + ) + + # Check if user has already consented + has_consent = await provider_service.check_consent( + db, current_user.id, client_id, valid_scopes + ) + + if not has_consent: + # Redirect to consent page + consent_params = urlencode({ + "client_id": client_id, + "client_name": client.client_name, + "redirect_uri": redirect_uri, + "scope": " ".join(valid_scopes), + "state": state, + "code_challenge": code_challenge or "", + "code_challenge_method": code_challenge_method or "", + "nonce": nonce or "", + }) + return RedirectResponse( + url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}", + status_code=status.HTTP_302_FOUND, + ) + + # User is authenticated and has consented - issue authorization code + try: + code = await provider_service.create_authorization_code( + db=db, + client=client, + user=current_user, + redirect_uri=redirect_uri, + scope=" ".join(valid_scopes), + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + state=state, + nonce=nonce, + ) + except provider_service.OAuthProviderError as e: + error_params = { + "error": e.error, + "error_description": e.error_description, + } + if state: + error_params["state"] = state + return RedirectResponse( + url=f"{redirect_uri}?{urlencode(error_params)}", + status_code=status.HTTP_302_FOUND, + ) + + # Success - redirect with code + success_params = {"code": code} + if state: + success_params["state"] = state + + return RedirectResponse( + url=f"{redirect_uri}?{urlencode(success_params)}", + status_code=status.HTTP_302_FOUND, ) +@router.post( + "/provider/authorize/consent", + summary="Submit Authorization Consent", + description=""" + Submit user consent for OAuth authorization. + + Called by the consent page after user approves or denies. + """, + operation_id="oauth_provider_consent", + tags=["OAuth Provider"], +) +@limiter.limit("30/minute") +async def submit_consent( + request: Request, + approved: bool = Form(..., description="Whether user approved"), + client_id: str = Form(..., description="OAuth client ID"), + redirect_uri: str = Form(..., description="Redirect URI"), + scope: str = Form(default="", description="Granted scopes"), + state: str = Form(default="", description="CSRF state parameter"), + code_challenge: str | None = Form(default=None), + code_challenge_method: str | None = Form(default=None), + nonce: str | None = Form(default=None), + db: AsyncSession = Depends(get_db), + _: None = Depends(require_provider_enabled), + current_user: User = Depends(get_current_active_user), +) -> Any: + """Process consent form submission.""" + # Validate client + try: + client = await provider_service.get_client(db, client_id) + if not client: + raise provider_service.InvalidClientError("Unknown client_id") + provider_service.validate_redirect_uri(client, redirect_uri) + except provider_service.OAuthProviderError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"{e.error}: {e.error_description}", + ) + + # If user denied, redirect with error + if not approved: + error_params = { + "error": "access_denied", + "error_description": "User denied authorization", + } + if state: + error_params["state"] = state + return RedirectResponse( + url=f"{redirect_uri}?{urlencode(error_params)}", + status_code=status.HTTP_302_FOUND, + ) + + # Parse and validate scopes + granted_scopes = provider_service.parse_scope(scope) + valid_scopes = provider_service.validate_scopes(client, granted_scopes) + + # Record consent + await provider_service.grant_consent( + db, current_user.id, client_id, valid_scopes + ) + + # Generate authorization code + try: + code = await provider_service.create_authorization_code( + db=db, + client=client, + user=current_user, + redirect_uri=redirect_uri, + scope=" ".join(valid_scopes), + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + state=state, + nonce=nonce, + ) + except provider_service.OAuthProviderError as e: + error_params = { + "error": e.error, + "error_description": e.error_description, + } + if state: + error_params["state"] = state + return RedirectResponse( + url=f"{redirect_uri}?{urlencode(error_params)}", + status_code=status.HTTP_302_FOUND, + ) + + # Success + success_params = {"code": code} + if state: + success_params["state"] = state + + return RedirectResponse( + url=f"{redirect_uri}?{urlencode(success_params)}", + status_code=status.HTTP_302_FOUND, + ) + + +# ============================================================================ +# Token Endpoint +# ============================================================================ + + @router.post( "/provider/token", - summary="Token Endpoint (Skeleton)", + response_model=OAuthTokenResponse, + summary="Token Endpoint", description=""" OAuth 2.0 Token Endpoint. - **NOTE**: This is a skeleton implementation. In a full implementation, - this would exchange authorization codes for access tokens. + Supports: + - authorization_code: Exchange code for tokens + - refresh_token: Refresh access token - Currently returns a 501 Not Implemented response. + Client authentication: + - Confidential clients: client_secret (Basic auth or POST body) + - Public clients: No secret, but PKCE code_verifier required """, operation_id="oauth_provider_token", tags=["OAuth Provider"], ) +@limiter.limit("60/minute") async def token( - grant_type: str = Form(..., description="Grant type (authorization_code)"), + request: Request, + grant_type: str = Form(..., description="Grant type"), code: str | None = Form(default=None, description="Authorization code"), redirect_uri: str | None = Form(default=None, description="Redirect URI"), client_id: str | None = Form(default=None, description="Client ID"), client_secret: str | None = Form(default=None, description="Client secret"), code_verifier: str | None = Form(default=None, description="PKCE code verifier"), refresh_token: str | None = Form(default=None, description="Refresh token"), + scope: str | None = Form(default=None, description="Scope (for refresh)"), db: AsyncSession = Depends(get_db), -) -> Any: - """ - Token endpoint (skeleton). + _: None = Depends(require_provider_enabled), +) -> OAuthTokenResponse: + """Token endpoint - exchange code for tokens or refresh.""" + # Extract client credentials from Basic auth if not in body + if not client_id: + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Basic "): + import base64 + try: + decoded = base64.b64decode(auth_header[6:]).decode() + client_id, client_secret = decoded.split(":", 1) + except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body + pass - Supported grant types (when fully implemented): - - authorization_code: Exchange code for tokens - - refresh_token: Refresh access token - """ - if not settings.OAUTH_PROVIDER_ENABLED: + if not client_id: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="OAuth provider mode is not enabled", + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid_client: client_id required", + headers={"WWW-Authenticate": "Basic"}, ) - if grant_type not in ["authorization_code", "refresh_token"]: + # Get device info + device_info = request.headers.get("User-Agent", "")[:500] + ip_address = get_remote_address(request) + + try: + if grant_type == "authorization_code": + if not code: + raise provider_service.InvalidRequestError("code required") + if not redirect_uri: + raise provider_service.InvalidRequestError("redirect_uri required") + + result = await provider_service.exchange_authorization_code( + db=db, + code=code, + client_id=client_id, + redirect_uri=redirect_uri, + code_verifier=code_verifier, + client_secret=client_secret, + device_info=device_info, + ip_address=ip_address, + ) + + elif grant_type == "refresh_token": + if not refresh_token: + raise provider_service.InvalidRequestError("refresh_token required") + + result = await provider_service.refresh_tokens( + db=db, + refresh_token=refresh_token, + client_id=client_id, + client_secret=client_secret, + scope=scope, + device_info=device_info, + ip_address=ip_address, + ) + + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="unsupported_grant_type: Must be authorization_code or refresh_token", + ) + + return OAuthTokenResponse(**result) + + except provider_service.InvalidClientError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"{e.error}: {e.error_description}", + headers={"WWW-Authenticate": "Basic"}, + ) + except provider_service.OAuthProviderError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="unsupported_grant_type", + detail=f"{e.error}: {e.error_description}", ) - # Skeleton: Return not implemented - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="Token endpoint not fully implemented. " - "This is a skeleton for MCP integration.", - ) + +# ============================================================================ +# Token Revocation (RFC 7009) +# ============================================================================ @router.post( "/provider/revoke", - summary="Token Revocation Endpoint (Skeleton)", + status_code=status.HTTP_200_OK, + summary="Token Revocation Endpoint", description=""" OAuth 2.0 Token Revocation Endpoint (RFC 7009). - **NOTE**: This is a skeleton implementation. - - Currently returns a 501 Not Implemented response. + Revokes an access token or refresh token. + Always returns 200 OK (even if token is invalid) per spec. """, operation_id="oauth_provider_revoke", tags=["OAuth Provider"], ) +@limiter.limit("30/minute") async def revoke( + request: Request, token: str = Form(..., description="Token to revoke"), token_type_hint: str | None = Form( default=None, description="Token type hint (access_token, refresh_token)" @@ -225,88 +536,286 @@ async def revoke( client_id: str | None = Form(default=None, description="Client ID"), client_secret: str | None = Form(default=None, description="Client secret"), db: AsyncSession = Depends(get_db), -) -> Any: - """ - Token revocation endpoint (skeleton). + _: None = Depends(require_provider_enabled), +) -> dict[str, str]: + """Revoke a token.""" + # Extract client credentials from Basic auth if not in body + if not client_id: + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Basic "): + import base64 + try: + decoded = base64.b64decode(auth_header[6:]).decode() + client_id, client_secret = decoded.split(":", 1) + except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body + pass - In a full implementation, this would invalidate the specified token. - """ - if not settings.OAUTH_PROVIDER_ENABLED: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="OAuth provider mode is not enabled", + try: + await provider_service.revoke_token( + db=db, + token=token, + token_type_hint=token_type_hint, + client_id=client_id, + client_secret=client_secret, ) + except provider_service.InvalidClientError: + # Per RFC 7009, we should return 200 OK even for errors + # But client authentication errors can return 401 + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid_client", + headers={"WWW-Authenticate": "Basic"}, + ) + except Exception as e: + # Log but don't expose errors per RFC 7009 + logger.warning(f"Token revocation error: {e}") - # Skeleton: Return not implemented - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="Revocation endpoint not fully implemented. " - "This is a skeleton for MCP integration.", - ) + # Always return 200 OK per RFC 7009 + return {"status": "ok"} # ============================================================================ -# Client Management (Admin only) +# Token Introspection (RFC 7662) +# ============================================================================ + + +@router.post( + "/provider/introspect", + response_model=OAuthTokenIntrospectionResponse, + summary="Token Introspection Endpoint", + description=""" + OAuth 2.0 Token Introspection Endpoint (RFC 7662). + + Allows resource servers to query the authorization server + to determine the active state and metadata of a token. + """, + operation_id="oauth_provider_introspect", + tags=["OAuth Provider"], +) +@limiter.limit("120/minute") +async def introspect( + request: Request, + token: str = Form(..., description="Token to introspect"), + token_type_hint: str | None = Form( + default=None, description="Token type hint (access_token, refresh_token)" + ), + client_id: str | None = Form(default=None, description="Client ID"), + client_secret: str | None = Form(default=None, description="Client secret"), + db: AsyncSession = Depends(get_db), + _: None = Depends(require_provider_enabled), +) -> OAuthTokenIntrospectionResponse: + """Introspect a token.""" + # Extract client credentials from Basic auth if not in body + if not client_id: + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Basic "): + import base64 + try: + decoded = base64.b64decode(auth_header[6:]).decode() + client_id, client_secret = decoded.split(":", 1) + except Exception: # noqa: S110 - Intentional: malformed Basic auth falls back to form body + pass + + try: + result = await provider_service.introspect_token( + db=db, + token=token, + token_type_hint=token_type_hint, + client_id=client_id, + client_secret=client_secret, + ) + return OAuthTokenIntrospectionResponse(**result) + except provider_service.InvalidClientError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid_client", + headers={"WWW-Authenticate": "Basic"}, + ) + except Exception as e: + logger.warning(f"Token introspection error: {e}") + return OAuthTokenIntrospectionResponse(active=False) + + +# ============================================================================ +# Client Management (Admin) # ============================================================================ @router.post( "/provider/clients", - summary="Register OAuth Client (Admin)", + response_model=dict, + summary="Register OAuth Client", description=""" Register a new OAuth client (admin only). - This endpoint allows creating MCP clients that can authenticate - against this API. + Creates an MCP client that can authenticate against this API. + Returns client_id and client_secret (for confidential clients). - **NOTE**: This is a minimal implementation. + **Important:** Store the client_secret securely - it won't be shown again! """, operation_id="register_oauth_client", - tags=["OAuth Provider"], + tags=["OAuth Provider Admin"], ) async def register_client( client_name: str = Form(..., description="Client application name"), - redirect_uris: str = Form(..., description="Comma-separated list of redirect URIs"), + redirect_uris: str = Form(..., description="Comma-separated redirect URIs"), client_type: str = Form(default="public", description="public or confidential"), + scopes: str = Form( + default="openid profile email", + description="Allowed scopes (space-separated)", + ), + mcp_server_url: str | None = Form(default=None, description="MCP server URL"), db: AsyncSession = Depends(get_db), -) -> Any: - """ - Register a new OAuth client (skeleton). - - In a full implementation, this would require admin authentication. - """ - if not settings.OAUTH_PROVIDER_ENABLED: + _: None = Depends(require_provider_enabled), + current_user: User = Depends(get_current_superuser), +) -> dict: + """Register a new OAuth client.""" + # Parse redirect URIs + uris = [uri.strip() for uri in redirect_uris.split(",") if uri.strip()] + if not uris: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="OAuth provider mode is not enabled", + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one redirect_uri is required", ) - # NOTE: In production, this should require admin authentication - # For now, this is a skeleton that shows the structure - - from app.schemas.oauth import OAuthClientCreate + # Parse scopes + allowed_scopes = [s.strip() for s in scopes.split() if s.strip()] client_data = OAuthClientCreate( client_name=client_name, client_description=None, - redirect_uris=[uri.strip() for uri in redirect_uris.split(",")], - allowed_scopes=["openid", "profile", "email"], + redirect_uris=uris, + allowed_scopes=allowed_scopes, client_type=client_type, ) - client, secret = await oauth_client.create_client(db, obj_in=client_data) + client, secret = await oauth_client_crud.create_client(db, obj_in=client_data) + + # Update MCP server URL if provided + if mcp_server_url: + client.mcp_server_url = mcp_server_url + await db.commit() result = { "client_id": client.client_id, "client_name": client.client_name, "client_type": client.client_type, "redirect_uris": client.redirect_uris, + "allowed_scopes": client.allowed_scopes, } if secret: result["client_secret"] = secret result["warning"] = ( - "Store the client_secret securely. It will not be shown again." + "Store the client_secret securely! It will not be shown again." ) return result + + +@router.get( + "/provider/clients", + response_model=list[OAuthClientResponse], + summary="List OAuth Clients", + description="List all registered OAuth clients (admin only).", + operation_id="list_oauth_clients", + tags=["OAuth Provider Admin"], +) +async def list_clients( + db: AsyncSession = Depends(get_db), + _: None = Depends(require_provider_enabled), + current_user: User = Depends(get_current_superuser), +) -> list[OAuthClientResponse]: + """List all OAuth clients.""" + clients = await oauth_client_crud.get_all_clients(db) + return [OAuthClientResponse.model_validate(c) for c in clients] + + +@router.delete( + "/provider/clients/{client_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Delete OAuth Client", + description="Delete an OAuth client (admin only). Revokes all tokens.", + operation_id="delete_oauth_client", + tags=["OAuth Provider Admin"], +) +async def delete_client( + client_id: str, + db: AsyncSession = Depends(get_db), + _: None = Depends(require_provider_enabled), + current_user: User = Depends(get_current_superuser), +) -> None: + """Delete an OAuth client.""" + client = await provider_service.get_client(db, client_id) + if not client: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Client not found", + ) + + await oauth_client_crud.delete_client(db, client_id=client_id) + + +# ============================================================================ +# User Consent Management +# ============================================================================ + + +@router.get( + "/provider/consents", + summary="List My Consents", + description="List OAuth applications the current user has authorized.", + operation_id="list_my_oauth_consents", + tags=["OAuth Provider"], +) +async def list_my_consents( + db: AsyncSession = Depends(get_db), + _: None = Depends(require_provider_enabled), + current_user: User = Depends(get_current_active_user), +) -> list[dict]: + """List applications the user has authorized.""" + from sqlalchemy import select + + from app.models.oauth_client import OAuthClient + from app.models.oauth_provider_token import OAuthConsent + + result = await db.execute( + select(OAuthConsent, OAuthClient) + .join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id) + .where(OAuthConsent.user_id == current_user.id) + ) + rows = result.all() + + return [ + { + "client_id": consent.client_id, + "client_name": client.client_name, + "client_description": client.client_description, + "granted_scopes": consent.granted_scopes.split() if consent.granted_scopes else [], + "granted_at": consent.created_at.isoformat(), + } + for consent, client in rows + ] + + +@router.delete( + "/provider/consents/{client_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Revoke My Consent", + description="Revoke authorization for an OAuth application. Also revokes all tokens.", + operation_id="revoke_my_oauth_consent", + tags=["OAuth Provider"], +) +async def revoke_my_consent( + client_id: str, + db: AsyncSession = Depends(get_db), + _: None = Depends(require_provider_enabled), + current_user: User = Depends(get_current_active_user), +) -> None: + """Revoke consent for an application.""" + revoked = await provider_service.revoke_consent(db, current_user.id, client_id) + if not revoked: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No consent found for this client", + ) diff --git a/backend/app/crud/oauth.py b/backend/app/crud/oauth.py index e11307d..22e4a06 100755 --- a/backend/app/crud/oauth.py +++ b/backend/app/crud/oauth.py @@ -643,6 +643,62 @@ class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]): logger.error(f"Error verifying client secret: {e!s}") return False + async def get_all_clients( + self, db: AsyncSession, *, include_inactive: bool = False + ) -> list[OAuthClient]: + """ + Get all OAuth clients. + + Args: + db: Database session + include_inactive: Whether to include inactive clients + + Returns: + List of OAuthClient objects + """ + try: + query = select(OAuthClient).order_by(OAuthClient.created_at.desc()) + if not include_inactive: + query = query.where(OAuthClient.is_active == True) # noqa: E712 + + result = await db.execute(query) + return list(result.scalars().all()) + except Exception as e: # pragma: no cover + logger.error(f"Error getting all OAuth clients: {e!s}") + raise + + async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool: + """ + Delete an OAuth client permanently. + + Note: This will cascade delete related records (tokens, consents, etc.) + due to foreign key constraints. + + Args: + db: Database session + client_id: OAuth client ID + + Returns: + True if deleted, False if not found + """ + try: + result = await db.execute( + delete(OAuthClient).where(OAuthClient.client_id == client_id) + ) + await db.commit() + + deleted = result.rowcount > 0 + if deleted: + logger.info(f"OAuth client deleted: {client_id}") + else: + logger.warning(f"OAuth client not found for deletion: {client_id}") + + return deleted + except Exception as e: # pragma: no cover + await db.rollback() + logger.error(f"Error deleting OAuth client {client_id}: {e!s}") + raise + # ============================================================================ # Singleton instances diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 0e65351..f968f28 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -8,9 +8,13 @@ from app.core.database import Base from .base import TimestampMixin, UUIDMixin -# OAuth models +# OAuth models (client mode - authenticate via Google/GitHub) from .oauth_account import OAuthAccount + +# OAuth provider models (server mode - act as authorization server for MCP) +from .oauth_authorization_code import OAuthAuthorizationCode from .oauth_client import OAuthClient +from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken from .oauth_state import OAuthState from .organization import Organization @@ -22,7 +26,10 @@ from .user_session import UserSession __all__ = [ "Base", "OAuthAccount", + "OAuthAuthorizationCode", "OAuthClient", + "OAuthConsent", + "OAuthProviderRefreshToken", "OAuthState", "Organization", "OrganizationRole", diff --git a/backend/app/models/oauth_authorization_code.py b/backend/app/models/oauth_authorization_code.py new file mode 100644 index 0000000..5f0543c --- /dev/null +++ b/backend/app/models/oauth_authorization_code.py @@ -0,0 +1,91 @@ +"""OAuth authorization code model for OAuth provider mode.""" + +from datetime import datetime + +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship + +from .base import Base, TimestampMixin, UUIDMixin + + +class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin): + """ + OAuth 2.0 Authorization Code for the authorization code flow. + + Authorization codes are: + - Single-use (marked as used after exchange) + - Short-lived (10 minutes default) + - Bound to specific client, user, redirect_uri + - Support PKCE (code_challenge/code_challenge_method) + + Security considerations: + - Code must be cryptographically random (64 chars, URL-safe) + - Must validate redirect_uri matches exactly + - Must verify PKCE code_verifier for public clients + - Must be consumed within expiration time + """ + + __tablename__ = "oauth_authorization_codes" + + # The authorization code (cryptographically random, URL-safe) + code = Column(String(128), unique=True, nullable=False, index=True) + + # Client that requested the code + client_id = Column( + String(64), + ForeignKey("oauth_clients.client_id", ondelete="CASCADE"), + nullable=False, + ) + + # User who authorized the request + user_id = Column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ) + + # Redirect URI (must match exactly on token exchange) + redirect_uri = Column(String(2048), nullable=False) + + # Granted scopes (space-separated) + scope = Column(String(1000), nullable=False, default="") + + # PKCE support (required for public clients) + code_challenge = Column(String(128), nullable=True) + code_challenge_method = Column(String(10), nullable=True) # "S256" or "plain" + + # State parameter (for CSRF protection, returned to client) + state = Column(String(256), nullable=True) + + # Nonce (for OpenID Connect, included in ID token) + nonce = Column(String(256), nullable=True) + + # Expiration (codes are short-lived, typically 10 minutes) + expires_at = Column(DateTime(timezone=True), nullable=False) + + # Single-use flag (set to True after successful exchange) + used = Column(Boolean, default=False, nullable=False) + + # Relationships + client = relationship("OAuthClient", backref="authorization_codes") + user = relationship("User", backref="oauth_authorization_codes") + + # Indexes for efficient cleanup queries + __table_args__ = ( + Index("ix_oauth_authorization_codes_expires_at", "expires_at"), + Index("ix_oauth_authorization_codes_client_user", "client_id", "user_id"), + ) + + def __repr__(self): + return f"" + + @property + def is_expired(self) -> bool: + """Check if the authorization code has expired.""" + return datetime.utcnow() > self.expires_at.replace(tzinfo=None) + + @property + def is_valid(self) -> bool: + """Check if the authorization code is valid (not used, not expired).""" + return not self.used and not self.is_expired diff --git a/backend/app/models/oauth_provider_token.py b/backend/app/models/oauth_provider_token.py new file mode 100644 index 0000000..2f99826 --- /dev/null +++ b/backend/app/models/oauth_provider_token.py @@ -0,0 +1,153 @@ +"""OAuth provider token models for OAuth provider mode.""" + +from datetime import datetime + +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship + +from .base import Base, TimestampMixin, UUIDMixin + + +class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin): + """ + OAuth 2.0 Refresh Token for the OAuth provider. + + Refresh tokens are: + - Opaque (stored as hash in DB, actual token given to client) + - Long-lived (configurable, default 30 days) + - Revocable (via revoked flag or deletion) + - Bound to specific client, user, and scope + + Access tokens are JWTs and not stored in DB (self-contained). + This model only tracks refresh tokens for revocation support. + + Security considerations: + - Store token hash, not plaintext + - Support token rotation (new refresh token on use) + - Track last used time for security auditing + - Support revocation by user, client, or admin + """ + + __tablename__ = "oauth_provider_refresh_tokens" + + # Hash of the refresh token (SHA-256) + # We store hash, not plaintext, for security + token_hash = Column(String(64), unique=True, nullable=False, index=True) + + # Unique token ID (JTI) - used in JWT access tokens to reference this refresh token + jti = Column(String(64), unique=True, nullable=False, index=True) + + # Client that owns this token + client_id = Column( + String(64), + ForeignKey("oauth_clients.client_id", ondelete="CASCADE"), + nullable=False, + ) + + # User who authorized this token + user_id = Column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ) + + # Granted scopes (space-separated) + scope = Column(String(1000), nullable=False, default="") + + # Token expiration + expires_at = Column(DateTime(timezone=True), nullable=False) + + # Revocation flag + revoked = Column(Boolean, default=False, nullable=False, index=True) + + # Last used timestamp (for security auditing) + last_used_at = Column(DateTime(timezone=True), nullable=True) + + # Device/session info (optional, for user visibility) + device_info = Column(String(500), nullable=True) + ip_address = Column(String(45), nullable=True) + + # Relationships + client = relationship("OAuthClient", backref="refresh_tokens") + user = relationship("User", backref="oauth_provider_refresh_tokens") + + # Indexes + __table_args__ = ( + Index("ix_oauth_provider_refresh_tokens_expires_at", "expires_at"), + Index("ix_oauth_provider_refresh_tokens_client_user", "client_id", "user_id"), + Index( + "ix_oauth_provider_refresh_tokens_user_revoked", + "user_id", + "revoked", + ), + ) + + def __repr__(self): + status = "revoked" if self.revoked else "active" + return f"" + + @property + def is_expired(self) -> bool: + """Check if the refresh token has expired.""" + return datetime.utcnow() > self.expires_at.replace(tzinfo=None) + + @property + def is_valid(self) -> bool: + """Check if the refresh token is valid (not revoked, not expired).""" + return not self.revoked and not self.is_expired + + +class OAuthConsent(Base, UUIDMixin, TimestampMixin): + """ + OAuth consent record - remembers user consent for a client. + + When a user grants consent to an OAuth client, we store the record + so they don't have to re-consent on subsequent authorizations + (unless scopes change). + + This enables a better UX - users only see consent screen once per client, + unless the client requests additional scopes. + """ + + __tablename__ = "oauth_consents" + + # User who granted consent + user_id = Column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ) + + # Client that received consent + client_id = Column( + String(64), + ForeignKey("oauth_clients.client_id", ondelete="CASCADE"), + nullable=False, + ) + + # Granted scopes (space-separated) + granted_scopes = Column(String(1000), nullable=False, default="") + + # Relationships + client = relationship("OAuthClient", backref="consents") + user = relationship("User", backref="oauth_consents") + + # Unique constraint: one consent record per user+client + __table_args__ = ( + Index( + "ix_oauth_consents_user_client", + "user_id", + "client_id", + unique=True, + ), + ) + + def __repr__(self): + return f"" + + def has_scopes(self, requested_scopes: list[str]) -> bool: + """Check if all requested scopes are already granted.""" + granted = set(self.granted_scopes.split()) if self.granted_scopes else set() + requested = set(requested_scopes) + return requested.issubset(granted) diff --git a/backend/app/schemas/oauth.py b/backend/app/schemas/oauth.py index c1df309..aa7d908 100644 --- a/backend/app/schemas/oauth.py +++ b/backend/app/schemas/oauth.py @@ -284,6 +284,9 @@ class OAuthServerMetadata(BaseModel): revocation_endpoint: str | None = Field( None, description="Token revocation endpoint" ) + introspection_endpoint: str | None = Field( + None, description="Token introspection endpoint (RFC 7662)" + ) scopes_supported: list[str] = Field( default_factory=list, description="Supported scopes" ) @@ -297,6 +300,10 @@ class OAuthServerMetadata(BaseModel): code_challenge_methods_supported: list[str] = Field( default_factory=lambda: ["S256"], description="Supported PKCE methods" ) + token_endpoint_auth_methods_supported: list[str] = Field( + default_factory=lambda: ["client_secret_basic", "client_secret_post", "none"], + description="Supported client authentication methods", + ) model_config = ConfigDict( json_schema_extra={ @@ -304,10 +311,105 @@ class OAuthServerMetadata(BaseModel): "issuer": "https://api.example.com", "authorization_endpoint": "https://api.example.com/oauth/authorize", "token_endpoint": "https://api.example.com/oauth/token", + "revocation_endpoint": "https://api.example.com/oauth/revoke", + "introspection_endpoint": "https://api.example.com/oauth/introspect", "scopes_supported": ["openid", "profile", "email", "read:users"], "response_types_supported": ["code"], "grant_types_supported": ["authorization_code", "refresh_token"], "code_challenge_methods_supported": ["S256"], + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "client_secret_post", + "none", + ], + } + } + ) + + +# ============================================================================ +# OAuth Token Responses (RFC 6749) +# ============================================================================ + + +class OAuthTokenResponse(BaseModel): + """OAuth 2.0 Token Response (RFC 6749 Section 5.1).""" + + access_token: str = Field(..., description="The access token issued by the server") + token_type: str = Field( + default="Bearer", description="The type of token (typically 'Bearer')" + ) + expires_in: int | None = Field( + None, description="Token lifetime in seconds" + ) + refresh_token: str | None = Field( + None, description="Refresh token for obtaining new access tokens" + ) + scope: str | None = Field( + None, description="Space-separated list of granted scopes" + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "dGhpcyBpcyBhIHJlZnJlc2ggdG9rZW4...", + "scope": "openid profile email", + } + } + ) + + +class OAuthTokenIntrospectionResponse(BaseModel): + """OAuth 2.0 Token Introspection Response (RFC 7662).""" + + active: bool = Field( + ..., description="Whether the token is currently active" + ) + scope: str | None = Field( + None, description="Space-separated list of scopes" + ) + client_id: str | None = Field( + None, description="Client identifier for the token" + ) + username: str | None = Field( + None, description="Human-readable identifier for the resource owner" + ) + token_type: str | None = Field( + None, description="Type of the token (e.g., 'Bearer')" + ) + exp: int | None = Field( + None, description="Token expiration time (Unix timestamp)" + ) + iat: int | None = Field( + None, description="Token issue time (Unix timestamp)" + ) + nbf: int | None = Field( + None, description="Token not-before time (Unix timestamp)" + ) + sub: str | None = Field( + None, description="Subject of the token (user ID)" + ) + aud: str | None = Field( + None, description="Intended audience of the token" + ) + iss: str | None = Field( + None, description="Issuer of the token" + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "active": True, + "scope": "openid profile", + "client_id": "client123", + "username": "user@example.com", + "token_type": "Bearer", + "exp": 1735689600, + "iat": 1735686000, + "sub": "user-uuid-here", } } ) diff --git a/backend/app/services/oauth_provider_service.py b/backend/app/services/oauth_provider_service.py new file mode 100644 index 0000000..a198364 --- /dev/null +++ b/backend/app/services/oauth_provider_service.py @@ -0,0 +1,1008 @@ +""" +OAuth Provider Service for MCP integration. + +Implements OAuth 2.0 Authorization Server functionality: +- Authorization code flow with PKCE +- Token issuance (JWT access tokens, opaque refresh tokens) +- Token refresh +- Token revocation +- Consent management + +Security features: +- PKCE required for public clients (S256) +- Short-lived authorization codes (10 minutes) +- JWT access tokens (self-contained, no DB lookup) +- Secure refresh token storage (hashed) +- Token rotation on refresh +- Comprehensive validation +""" + +import base64 +import hashlib +import logging +import secrets +from datetime import UTC, datetime, timedelta +from typing import Any +from uuid import UUID + +from jose import jwt +from sqlalchemy import and_, delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import settings +from app.models.oauth_authorization_code import OAuthAuthorizationCode +from app.models.oauth_client import OAuthClient +from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken +from app.models.user import User + +logger = logging.getLogger(__name__) + +# Constants +AUTHORIZATION_CODE_EXPIRY_MINUTES = 10 +ACCESS_TOKEN_EXPIRY_MINUTES = 60 # 1 hour for MCP clients +REFRESH_TOKEN_EXPIRY_DAYS = 30 + + +class OAuthProviderError(Exception): + """Base exception for OAuth provider errors.""" + + def __init__( + self, + error: str, + error_description: str | None = None, + error_uri: str | None = None, + ): + self.error = error + self.error_description = error_description + self.error_uri = error_uri + super().__init__(error_description or error) + + +class InvalidClientError(OAuthProviderError): + """Client authentication failed.""" + + def __init__(self, description: str = "Invalid client credentials"): + super().__init__("invalid_client", description) + + +class InvalidGrantError(OAuthProviderError): + """Invalid authorization grant.""" + + def __init__(self, description: str = "Invalid grant"): + super().__init__("invalid_grant", description) + + +class InvalidRequestError(OAuthProviderError): + """Invalid request parameters.""" + + def __init__(self, description: str = "Invalid request"): + super().__init__("invalid_request", description) + + +class InvalidScopeError(OAuthProviderError): + """Invalid scope requested.""" + + def __init__(self, description: str = "Invalid scope"): + super().__init__("invalid_scope", description) + + +class UnauthorizedClientError(OAuthProviderError): + """Client not authorized for this grant type.""" + + def __init__(self, description: str = "Unauthorized client"): + super().__init__("unauthorized_client", description) + + +class AccessDeniedError(OAuthProviderError): + """User denied authorization.""" + + def __init__(self, description: str = "Access denied"): + super().__init__("access_denied", description) + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def generate_code() -> str: + """Generate a cryptographically secure authorization code.""" + return secrets.token_urlsafe(64) + + +def generate_token() -> str: + """Generate a cryptographically secure token.""" + return secrets.token_urlsafe(48) + + +def generate_jti() -> str: + """Generate a unique JWT ID.""" + return secrets.token_urlsafe(32) + + +def hash_token(token: str) -> str: + """Hash a token using SHA-256.""" + return hashlib.sha256(token.encode()).hexdigest() + + +def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool: + """Verify PKCE code_verifier against stored code_challenge.""" + if method == "S256": + # SHA-256 hash, then base64url encode + digest = hashlib.sha256(code_verifier.encode()).digest() + computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + return secrets.compare_digest(computed, code_challenge) + elif method == "plain": + # Direct comparison (not recommended, but supported) + return secrets.compare_digest(code_verifier, code_challenge) + return False + + +def parse_scope(scope: str) -> list[str]: + """Parse space-separated scope string into list.""" + return [s.strip() for s in scope.split() if s.strip()] + + +def join_scope(scopes: list[str]) -> str: + """Join scope list into space-separated string.""" + return " ".join(sorted(set(scopes))) + + +# ============================================================================ +# Client Validation +# ============================================================================ + + +async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None: + """Get OAuth client by client_id.""" + result = await db.execute( + select(OAuthClient).where( + and_( + OAuthClient.client_id == client_id, + OAuthClient.is_active == True, # noqa: E712 + ) + ) + ) + return result.scalar_one_or_none() + + +async def validate_client( + db: AsyncSession, + client_id: str, + client_secret: str | None = None, + require_secret: bool = False, +) -> OAuthClient: + """ + Validate OAuth client credentials. + + Args: + db: Database session + client_id: Client identifier + client_secret: Client secret (required for confidential clients) + require_secret: Whether to require secret validation + + Returns: + Validated OAuthClient + + Raises: + InvalidClientError: If client validation fails + """ + client = await get_client(db, client_id) + if not client: + raise InvalidClientError("Unknown client_id") + + # Confidential clients must provide valid secret + if client.client_type == "confidential" or require_secret: + if not client_secret: + raise InvalidClientError("Client secret required") + if not client.client_secret_hash: + raise InvalidClientError("Client not configured with secret") + + # Verify secret using SHA256 hash (consistent with CRUD) + computed_hash = hashlib.sha256(client_secret.encode()).hexdigest() + if not secrets.compare_digest(computed_hash, client.client_secret_hash): + raise InvalidClientError("Invalid client secret") + + return client + + +def validate_redirect_uri(client: OAuthClient, redirect_uri: str) -> None: + """ + Validate redirect_uri against client's registered URIs. + + Raises: + InvalidRequestError: If redirect_uri is not registered + """ + if not client.redirect_uris: + raise InvalidRequestError("Client has no registered redirect URIs") + + if redirect_uri not in client.redirect_uris: + raise InvalidRequestError("Invalid redirect_uri") + + +def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[str]: + """ + Validate requested scopes against client's allowed scopes. + + Returns: + List of valid scopes (intersection of requested and allowed) + + Raises: + InvalidScopeError: If no valid scopes + """ + allowed = set(client.allowed_scopes or []) + requested = set(requested_scopes) + + # If no scopes requested, use all allowed scopes + if not requested: + return list(allowed) + + valid = requested & allowed + if not valid: + raise InvalidScopeError( + "None of the requested scopes are allowed for this client" + ) + + # Warn if some scopes were filtered out + invalid = requested - allowed + if invalid: + logger.warning( + f"Client {client.client_id} requested invalid scopes: {invalid}" + ) + + return list(valid) + + +# ============================================================================ +# Authorization Code Flow +# ============================================================================ + + +async def create_authorization_code( + db: AsyncSession, + client: OAuthClient, + user: User, + redirect_uri: str, + scope: str, + code_challenge: str | None = None, + code_challenge_method: str | None = None, + state: str | None = None, + nonce: str | None = None, +) -> str: + """ + Create an authorization code for the authorization code flow. + + Args: + db: Database session + client: Validated OAuth client + user: Authenticated user + redirect_uri: Validated redirect URI + scope: Granted scopes (space-separated) + code_challenge: PKCE code challenge + code_challenge_method: PKCE method (S256) + state: CSRF state parameter + nonce: OpenID Connect nonce + + Returns: + Authorization code string + """ + # Public clients MUST use PKCE + if client.client_type == "public": + if not code_challenge or code_challenge_method != "S256": + raise InvalidRequestError("PKCE with S256 is required for public clients") + + code = generate_code() + expires_at = datetime.now(UTC) + timedelta( + minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES + ) + + auth_code = OAuthAuthorizationCode( + code=code, + client_id=client.client_id, + user_id=user.id, + redirect_uri=redirect_uri, + scope=scope, + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + state=state, + nonce=nonce, + expires_at=expires_at, + used=False, + ) + + db.add(auth_code) + await db.commit() + + logger.info( + f"Created authorization code for user {user.id} and client {client.client_id}" + ) + return code + + +async def exchange_authorization_code( + db: AsyncSession, + code: str, + client_id: str, + redirect_uri: str, + code_verifier: str | None = None, + client_secret: str | None = None, + device_info: str | None = None, + ip_address: str | None = None, +) -> dict[str, Any]: + """ + Exchange authorization code for tokens. + + Args: + db: Database session + code: Authorization code + client_id: Client identifier + redirect_uri: Must match the original redirect_uri + code_verifier: PKCE code verifier + client_secret: Client secret (for confidential clients) + device_info: Optional device information + ip_address: Optional IP address + + Returns: + Token response dict with access_token, refresh_token, etc. + + Raises: + InvalidGrantError: If code is invalid, expired, or already used + InvalidClientError: If client validation fails + """ + # Get and validate authorization code + result = await db.execute( + select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code) + ) + auth_code = result.scalar_one_or_none() + + if not auth_code: + raise InvalidGrantError("Invalid authorization code") + + if auth_code.used: + # Code reuse is a security incident - revoke all tokens for this grant + logger.warning( + f"Authorization code reuse detected for client {auth_code.client_id}" + ) + await revoke_tokens_for_user_client(db, auth_code.user_id, auth_code.client_id) + raise InvalidGrantError("Authorization code has already been used") + + if auth_code.is_expired: + raise InvalidGrantError("Authorization code has expired") + + if auth_code.client_id != client_id: + raise InvalidGrantError("Authorization code was not issued to this client") + + if auth_code.redirect_uri != redirect_uri: + raise InvalidGrantError("redirect_uri mismatch") + + # Validate client + client = await validate_client( + db, + client_id, + client_secret, + require_secret=(client_secret is not None), + ) + + # Verify PKCE + if auth_code.code_challenge: + if not code_verifier: + raise InvalidGrantError("code_verifier required") + if not verify_pkce( + code_verifier, + auth_code.code_challenge, + auth_code.code_challenge_method or "S256", + ): + raise InvalidGrantError("Invalid code_verifier") + elif client.client_type == "public": + # Public clients without PKCE - this shouldn't happen if we validated on authorize + raise InvalidGrantError("PKCE required for public clients") + + # Mark code as used (single-use) + auth_code.used = True + await db.commit() + + # Get user + user_result = await db.execute(select(User).where(User.id == auth_code.user_id)) + user = user_result.scalar_one_or_none() + if not user or not user.is_active: + raise InvalidGrantError("User not found or inactive") + + # Generate tokens + return await create_tokens( + db=db, + client=client, + user=user, + scope=auth_code.scope, + nonce=auth_code.nonce, + device_info=device_info, + ip_address=ip_address, + ) + + +# ============================================================================ +# Token Generation +# ============================================================================ + + +async def create_tokens( + db: AsyncSession, + client: OAuthClient, + user: User, + scope: str, + nonce: str | None = None, + device_info: str | None = None, + ip_address: str | None = None, +) -> dict[str, Any]: + """ + Create access and refresh tokens. + + Args: + db: Database session + client: OAuth client + user: User + scope: Granted scopes + nonce: OpenID Connect nonce (included in ID token) + device_info: Optional device information + ip_address: Optional IP address + + Returns: + Token response dict + """ + now = datetime.now(UTC) + jti = generate_jti() + + # Access token expiry + access_token_lifetime = int(client.access_token_lifetime or "3600") + access_expires = now + timedelta(seconds=access_token_lifetime) + + # Refresh token expiry + refresh_token_lifetime = int(client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400)) + refresh_expires = now + timedelta(seconds=refresh_token_lifetime) + + # Create JWT access token + access_token_payload = { + "iss": settings.OAUTH_ISSUER, + "sub": str(user.id), + "aud": client.client_id, + "exp": int(access_expires.timestamp()), + "iat": int(now.timestamp()), + "jti": jti, + "scope": scope, + "client_id": client.client_id, + # User info (basic claims) + "email": user.email, + "name": f"{user.first_name or ''} {user.last_name or ''}".strip() or user.email, + } + + # Add nonce for OpenID Connect + if nonce: + access_token_payload["nonce"] = nonce + + access_token = jwt.encode( + access_token_payload, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM, + ) + + # Create opaque refresh token + refresh_token = generate_token() + refresh_token_hash = hash_token(refresh_token) + + # Store refresh token in database + refresh_token_record = OAuthProviderRefreshToken( + token_hash=refresh_token_hash, + jti=jti, + client_id=client.client_id, + user_id=user.id, + scope=scope, + expires_at=refresh_expires, + device_info=device_info, + ip_address=ip_address, + ) + db.add(refresh_token_record) + await db.commit() + + logger.info(f"Issued tokens for user {user.id} to client {client.client_id}") + + return { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": access_token_lifetime, + "refresh_token": refresh_token, + "scope": scope, + } + + +async def refresh_tokens( + db: AsyncSession, + refresh_token: str, + client_id: str, + client_secret: str | None = None, + scope: str | None = None, + device_info: str | None = None, + ip_address: str | None = None, +) -> dict[str, Any]: + """ + Refresh access token using refresh token. + + Implements token rotation - old refresh token is invalidated, + new refresh token is issued. + + Args: + db: Database session + refresh_token: Refresh token + client_id: Client identifier + client_secret: Client secret (for confidential clients) + scope: Optional reduced scope + device_info: Optional device information + ip_address: Optional IP address + + Returns: + New token response dict + + Raises: + InvalidGrantError: If refresh token is invalid + """ + # Find refresh token + token_hash = hash_token(refresh_token) + result = await db.execute( + select(OAuthProviderRefreshToken).where( + OAuthProviderRefreshToken.token_hash == token_hash + ) + ) + token_record = result.scalar_one_or_none() + + if not token_record: + raise InvalidGrantError("Invalid refresh token") + + if token_record.revoked: + # Token reuse after revocation - security incident + logger.warning( + f"Revoked refresh token reuse detected for client {token_record.client_id}" + ) + raise InvalidGrantError("Refresh token has been revoked") + + if token_record.is_expired: + raise InvalidGrantError("Refresh token has expired") + + if token_record.client_id != client_id: + raise InvalidGrantError("Refresh token was not issued to this client") + + # Validate client + client = await validate_client( + db, + client_id, + client_secret, + require_secret=(client_secret is not None), + ) + + # Get user + user_result = await db.execute( + select(User).where(User.id == token_record.user_id) + ) + user = user_result.scalar_one_or_none() + if not user or not user.is_active: + raise InvalidGrantError("User not found or inactive") + + # Validate scope (can only reduce, not expand) + original_scopes = set(parse_scope(token_record.scope)) + if scope: + requested_scopes = set(parse_scope(scope)) + if not requested_scopes.issubset(original_scopes): + raise InvalidScopeError("Cannot expand scope on refresh") + final_scope = join_scope(list(requested_scopes)) + else: + final_scope = token_record.scope + + # Revoke old refresh token (token rotation) + token_record.revoked = True + token_record.last_used_at = datetime.now(UTC) + await db.commit() + + # Issue new tokens + return await create_tokens( + db=db, + client=client, + user=user, + scope=final_scope, + device_info=device_info or token_record.device_info, + ip_address=ip_address or token_record.ip_address, + ) + + +# ============================================================================ +# Token Revocation +# ============================================================================ + + +async def revoke_token( + db: AsyncSession, + token: str, + token_type_hint: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, +) -> bool: + """ + Revoke a token (access or refresh). + + For refresh tokens: marks as revoked in database + For access tokens: we can't truly revoke JWTs, but we can revoke + the associated refresh token to prevent further refreshes + + Args: + db: Database session + token: Token to revoke + token_type_hint: "access_token" or "refresh_token" + client_id: Client identifier (for validation) + client_secret: Client secret (for confidential clients) + + Returns: + True if token was revoked, False if not found + """ + # Try as refresh token first (more likely) + if token_type_hint != "access_token": + token_hash = hash_token(token) + result = await db.execute( + select(OAuthProviderRefreshToken).where( + OAuthProviderRefreshToken.token_hash == token_hash + ) + ) + refresh_record = result.scalar_one_or_none() + + if refresh_record: + # Validate client if provided + if client_id and refresh_record.client_id != client_id: + raise InvalidClientError("Token was not issued to this client") + + refresh_record.revoked = True + await db.commit() + logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...") + return True + + # Try as access token (JWT) + if token_type_hint != "refresh_token": + try: + from jose.exceptions import JWTError + + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM], + options={"verify_exp": False, "verify_aud": False}, # Allow expired tokens + ) + jti = payload.get("jti") + if jti: + # Find and revoke the associated refresh token + result = await db.execute( + select(OAuthProviderRefreshToken).where( + OAuthProviderRefreshToken.jti == jti + ) + ) + refresh_record = result.scalar_one_or_none() + if refresh_record: + if client_id and refresh_record.client_id != client_id: + raise InvalidClientError("Token was not issued to this client") + refresh_record.revoked = True + await db.commit() + logger.info( + f"Revoked refresh token via access token JTI {jti[:8]}..." + ) + return True + except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT not an error + pass + + return False + + +async def revoke_tokens_for_user_client( + db: AsyncSession, + user_id: UUID, + client_id: str, +) -> int: + """ + Revoke all tokens for a specific user-client pair. + + Used when security incidents are detected (e.g., code reuse). + + Args: + db: Database session + user_id: User identifier + client_id: Client identifier + + Returns: + Number of tokens revoked + """ + result = await db.execute( + select(OAuthProviderRefreshToken).where( + and_( + OAuthProviderRefreshToken.user_id == user_id, + OAuthProviderRefreshToken.client_id == client_id, + OAuthProviderRefreshToken.revoked == False, # noqa: E712 + ) + ) + ) + tokens = result.scalars().all() + + count = 0 + for token in tokens: + token.revoked = True + count += 1 + + if count > 0: + await db.commit() + logger.warning( + f"Revoked {count} tokens for user {user_id} and client {client_id}" + ) + + return count + + +async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int: + """ + Revoke all OAuth provider tokens for a user. + + Used when user changes password or explicitly logs out everywhere. + + Args: + db: Database session + user_id: User identifier + + Returns: + Number of tokens revoked + """ + result = await db.execute( + select(OAuthProviderRefreshToken).where( + and_( + OAuthProviderRefreshToken.user_id == user_id, + OAuthProviderRefreshToken.revoked == False, # noqa: E712 + ) + ) + ) + tokens = result.scalars().all() + + count = 0 + for token in tokens: + token.revoked = True + count += 1 + + if count > 0: + await db.commit() + logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}") + + return count + + +# ============================================================================ +# Token Introspection (RFC 7662) +# ============================================================================ + + +async def introspect_token( + db: AsyncSession, + token: str, + token_type_hint: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, +) -> dict[str, Any]: + """ + Introspect a token to determine its validity and metadata. + + Implements RFC 7662 Token Introspection. + + Args: + db: Database session + token: Token to introspect + token_type_hint: "access_token" or "refresh_token" + client_id: Client requesting introspection + client_secret: Client secret + + Returns: + Introspection response dict + """ + # Validate client if credentials provided + if client_id: + await validate_client(db, client_id, client_secret) + + # Try as access token (JWT) first + if token_type_hint != "refresh_token": + try: + from jose.exceptions import ExpiredSignatureError, JWTError + + payload = jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM], + options={"verify_aud": False}, # Don't require audience match for introspection + ) + + # Check if associated refresh token is revoked + jti = payload.get("jti") + if jti: + result = await db.execute( + select(OAuthProviderRefreshToken).where( + OAuthProviderRefreshToken.jti == jti + ) + ) + refresh_record = result.scalar_one_or_none() + if refresh_record and refresh_record.revoked: + return {"active": False} + + return { + "active": True, + "scope": payload.get("scope", ""), + "client_id": payload.get("client_id"), + "username": payload.get("email"), + "token_type": "Bearer", + "exp": payload.get("exp"), + "iat": payload.get("iat"), + "sub": payload.get("sub"), + "aud": payload.get("aud"), + "iss": payload.get("iss"), + } + except ExpiredSignatureError: + return {"active": False} + except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT falls through to refresh token check + pass + + # Try as refresh token + if token_type_hint != "access_token": + token_hash = hash_token(token) + result = await db.execute( + select(OAuthProviderRefreshToken).where( + OAuthProviderRefreshToken.token_hash == token_hash + ) + ) + refresh_record = result.scalar_one_or_none() + + if refresh_record and refresh_record.is_valid: + return { + "active": True, + "scope": refresh_record.scope, + "client_id": refresh_record.client_id, + "token_type": "refresh_token", + "exp": int(refresh_record.expires_at.timestamp()), + "iat": int(refresh_record.created_at.timestamp()), + "sub": str(refresh_record.user_id), + } + + return {"active": False} + + +# ============================================================================ +# Consent Management +# ============================================================================ + + +async def get_consent( + db: AsyncSession, + user_id: UUID, + client_id: str, +) -> OAuthConsent | None: + """Get existing consent record for user-client pair.""" + result = await db.execute( + select(OAuthConsent).where( + and_( + OAuthConsent.user_id == user_id, + OAuthConsent.client_id == client_id, + ) + ) + ) + return result.scalar_one_or_none() + + +async def check_consent( + db: AsyncSession, + user_id: UUID, + client_id: str, + requested_scopes: list[str], +) -> bool: + """ + Check if user has already consented to the requested scopes. + + Returns True if all requested scopes are already granted. + """ + consent = await get_consent(db, user_id, client_id) + if not consent: + return False + return consent.has_scopes(requested_scopes) + + +async def grant_consent( + db: AsyncSession, + user_id: UUID, + client_id: str, + scopes: list[str], +) -> OAuthConsent: + """ + Grant or update consent for a user-client pair. + + If consent already exists, updates the granted scopes. + """ + consent = await get_consent(db, user_id, client_id) + + if consent: + # Merge scopes + existing = set(parse_scope(consent.granted_scopes)) + new_scopes = existing | set(scopes) + consent.granted_scopes = join_scope(list(new_scopes)) + else: + consent = OAuthConsent( + user_id=user_id, + client_id=client_id, + granted_scopes=join_scope(scopes), + ) + db.add(consent) + + await db.commit() + await db.refresh(consent) + return consent + + +async def revoke_consent( + db: AsyncSession, + user_id: UUID, + client_id: str, +) -> bool: + """ + Revoke consent and all tokens for a user-client pair. + + Returns True if consent was found and revoked. + """ + # Delete consent record + result = await db.execute( + delete(OAuthConsent).where( + and_( + OAuthConsent.user_id == user_id, + OAuthConsent.client_id == client_id, + ) + ) + ) + + # Revoke all tokens + await revoke_tokens_for_user_client(db, user_id, client_id) + + await db.commit() + return result.rowcount > 0 + + +# ============================================================================ +# Cleanup +# ============================================================================ + + +async def cleanup_expired_codes(db: AsyncSession) -> int: + """ + Delete expired authorization codes. + + Should be called periodically (e.g., every hour). + + Returns: + Number of codes deleted + """ + result = await db.execute( + delete(OAuthAuthorizationCode).where( + OAuthAuthorizationCode.expires_at < datetime.now(UTC) + ) + ) + await db.commit() + return result.rowcount + + +async def cleanup_expired_tokens(db: AsyncSession) -> int: + """ + Delete expired and revoked refresh tokens. + + Should be called periodically (e.g., daily). + + Returns: + Number of tokens deleted + """ + # Delete tokens that are both expired AND revoked (or just very old) + cutoff = datetime.now(UTC) - timedelta(days=7) + result = await db.execute( + delete(OAuthProviderRefreshToken).where( + OAuthProviderRefreshToken.expires_at < cutoff + ) + ) + await db.commit() + return result.rowcount diff --git a/backend/tests/api/test_oauth.py b/backend/tests/api/test_oauth.py index ad72c13..8e4be5d 100644 --- a/backend/tests/api/test_oauth.py +++ b/backend/tests/api/test_oauth.py @@ -344,8 +344,8 @@ class TestOAuthProviderEndpoints: assert response.status_code == 404 @pytest.mark.asyncio - async def test_provider_authorize_skeleton(self, client, async_test_db): - """Test provider authorize returns not implemented (skeleton).""" + async def test_provider_authorize_requires_auth(self, client, async_test_db): + """Test provider authorize requires authentication.""" _test_engine, AsyncTestingSessionLocal = async_test_db # Create a test client @@ -374,12 +374,12 @@ class TestOAuthProviderEndpoints: "redirect_uri": "http://localhost:3000/callback", }, ) - # Should return 501 Not Implemented (skeleton) - assert response.status_code == 501 + # Authorize endpoint requires authentication + assert response.status_code == 401 @pytest.mark.asyncio - async def test_provider_token_skeleton(self, client): - """Test provider token returns not implemented (skeleton).""" + async def test_provider_token_requires_client_id(self, client): + """Test provider token requires client_id.""" with patch("app.api.routes.oauth_provider.settings") as mock_settings: mock_settings.OAUTH_PROVIDER_ENABLED = True @@ -390,5 +390,5 @@ class TestOAuthProviderEndpoints: "code": "test_code", }, ) - # Should return 501 Not Implemented (skeleton) - assert response.status_code == 501 + # Missing client_id returns 401 (invalid_client) + assert response.status_code == 401 diff --git a/backend/tests/services/test_oauth_provider_service.py b/backend/tests/services/test_oauth_provider_service.py new file mode 100644 index 0000000..7e41714 --- /dev/null +++ b/backend/tests/services/test_oauth_provider_service.py @@ -0,0 +1,726 @@ +# tests/services/test_oauth_provider_service.py +""" +Tests for OAuth Provider Service (Authorization Server mode for MCP). + +Covers: +- Authorization code creation and exchange +- Token generation, refresh, and revocation +- PKCE verification +- Token introspection (RFC 7662) +- Consent management +- Error handling +""" + +import base64 +import hashlib +import secrets +from unittest.mock import patch +from uuid import uuid4 + +import pytest +import pytest_asyncio + +from app.models.oauth_client import OAuthClient +from app.models.user import User +from app.services import oauth_provider_service as service +from app.utils.test_utils import setup_async_test_db, teardown_async_test_db + + +@pytest_asyncio.fixture(scope="function") +async def db(): + """Fixture provides testing engine and session for each test.""" + test_engine, AsyncTestingSessionLocal = await setup_async_test_db() + async with AsyncTestingSessionLocal() as session: + yield session + await teardown_async_test_db(test_engine) + + +@pytest_asyncio.fixture +async def test_user(db): + """Create a test user.""" + user = User( + id=uuid4(), + email="testuser@example.com", + password_hash="$2b$12$test", + first_name="Test", + last_name="User", + is_active=True, + is_superuser=False, + ) + db.add(user) + await db.commit() + await db.refresh(user) + return user + + +@pytest_asyncio.fixture +async def public_client(db): + """Create a test public OAuth client.""" + client = OAuthClient( + id=uuid4(), + client_id="test_public_client", + client_name="Test Public Client", + client_type="public", + redirect_uris=["http://localhost:3000/callback"], + allowed_scopes=["openid", "profile", "email", "read:users"], + is_active=True, + ) + db.add(client) + await db.commit() + await db.refresh(client) + return client + + +@pytest_asyncio.fixture +async def confidential_client(db): + """Create a test confidential OAuth client.""" + secret = "test_client_secret" + secret_hash = hashlib.sha256(secret.encode()).hexdigest() + client = OAuthClient( + id=uuid4(), + client_id="test_confidential_client", + client_name="Test Confidential Client", + client_type="confidential", + client_secret_hash=secret_hash, + redirect_uris=["http://localhost:3000/callback"], + allowed_scopes=["openid", "profile", "email"], + is_active=True, + ) + db.add(client) + await db.commit() + await db.refresh(client) + return client, secret + + +class TestHelperFunctions: + """Tests for helper functions.""" + + def test_generate_code_length(self): + """Test authorization code generation has proper length.""" + code = service.generate_code() + assert len(code) > 64 # Base64 encoding of 64 bytes + + def test_generate_code_unique(self): + """Test authorization codes are unique.""" + codes = [service.generate_code() for _ in range(100)] + assert len(set(codes)) == 100 + + def test_generate_token(self): + """Test token generation.""" + token = service.generate_token() + assert len(token) > 32 + + def test_generate_jti(self): + """Test JTI generation.""" + jti = service.generate_jti() + assert len(jti) > 20 + + def test_hash_token(self): + """Test token hashing.""" + token = "test_token" + hashed = service.hash_token(token) + assert len(hashed) == 64 # SHA-256 hex digest + + def test_hash_token_deterministic(self): + """Test same token produces same hash.""" + token = "test_token" + hash1 = service.hash_token(token) + hash2 = service.hash_token(token) + assert hash1 == hash2 + + def test_parse_scope(self): + """Test scope parsing.""" + assert service.parse_scope("openid profile email") == [ + "openid", + "profile", + "email", + ] + assert service.parse_scope("") == [] + assert service.parse_scope(" openid profile ") == ["openid", "profile"] + + def test_join_scope(self): + """Test scope joining.""" + # Result is sorted and deduplicated + result = service.join_scope(["profile", "openid", "profile"]) + assert "openid" in result + assert "profile" in result + + +class TestPKCEVerification: + """Tests for PKCE verification.""" + + def test_verify_pkce_s256_valid(self): + """Test PKCE verification with S256 method.""" + # Generate code_verifier + code_verifier = secrets.token_urlsafe(64) + + # Generate code_challenge using S256 + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + + assert service.verify_pkce(code_verifier, code_challenge, "S256") is True + + def test_verify_pkce_s256_invalid(self): + """Test PKCE verification fails with wrong verifier.""" + code_verifier = secrets.token_urlsafe(64) + wrong_verifier = secrets.token_urlsafe(64) + + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + + assert service.verify_pkce(wrong_verifier, code_challenge, "S256") is False + + def test_verify_pkce_plain(self): + """Test PKCE verification with plain method.""" + code_verifier = "test_verifier" + assert service.verify_pkce(code_verifier, code_verifier, "plain") is True + assert service.verify_pkce(code_verifier, "wrong", "plain") is False + + def test_verify_pkce_unknown_method(self): + """Test PKCE verification with unknown method returns False.""" + assert service.verify_pkce("verifier", "challenge", "unknown") is False + + +class TestClientValidation: + """Tests for client validation.""" + + @pytest.mark.asyncio + async def test_get_client_success(self, db, public_client): + """Test getting a valid client.""" + client = await service.get_client(db, public_client.client_id) + assert client is not None + assert client.client_id == public_client.client_id + + @pytest.mark.asyncio + async def test_get_client_not_found(self, db): + """Test getting a non-existent client.""" + client = await service.get_client(db, "nonexistent") + assert client is None + + @pytest.mark.asyncio + async def test_get_client_inactive(self, db, public_client): + """Test getting an inactive client returns None.""" + public_client.is_active = False + await db.commit() + + client = await service.get_client(db, public_client.client_id) + assert client is None + + @pytest.mark.asyncio + async def test_validate_client_public(self, db, public_client): + """Test validating a public client.""" + client = await service.validate_client(db, public_client.client_id) + assert client.client_id == public_client.client_id + + @pytest.mark.asyncio + async def test_validate_client_confidential_with_secret( + self, db, confidential_client + ): + """Test validating a confidential client with correct secret.""" + client, secret = confidential_client + validated = await service.validate_client(db, client.client_id, secret) + assert validated.client_id == client.client_id + + @pytest.mark.asyncio + async def test_validate_client_confidential_wrong_secret( + self, db, confidential_client + ): + """Test validating a confidential client with wrong secret.""" + client, _ = confidential_client + with pytest.raises(service.InvalidClientError, match="Invalid client secret"): + await service.validate_client(db, client.client_id, "wrong_secret") + + @pytest.mark.asyncio + async def test_validate_client_confidential_no_secret(self, db, confidential_client): + """Test validating a confidential client without secret.""" + client, _ = confidential_client + with pytest.raises(service.InvalidClientError, match="Client secret required"): + await service.validate_client(db, client.client_id) + + def test_validate_redirect_uri_success(self, public_client): + """Test validating a registered redirect URI.""" + # Should not raise + service.validate_redirect_uri(public_client, "http://localhost:3000/callback") + + def test_validate_redirect_uri_invalid(self, public_client): + """Test validating an unregistered redirect URI.""" + with pytest.raises(service.InvalidRequestError, match="Invalid redirect_uri"): + service.validate_redirect_uri(public_client, "http://evil.com/callback") + + def test_validate_redirect_uri_no_uris(self, public_client): + """Test validating when client has no URIs.""" + public_client.redirect_uris = [] + with pytest.raises(service.InvalidRequestError, match="no registered"): + service.validate_redirect_uri(public_client, "http://localhost:3000") + + +class TestScopeValidation: + """Tests for scope validation.""" + + def test_validate_scopes_all_valid(self, public_client): + """Test validating all valid scopes.""" + scopes = service.validate_scopes(public_client, ["openid", "profile"]) + assert "openid" in scopes + assert "profile" in scopes + + def test_validate_scopes_partial_valid(self, public_client): + """Test validating with some invalid scopes - filters them out.""" + scopes = service.validate_scopes(public_client, ["openid", "invalid_scope"]) + assert "openid" in scopes + assert "invalid_scope" not in scopes + + def test_validate_scopes_empty_uses_all_allowed(self, public_client): + """Test empty scope request uses all allowed scopes.""" + scopes = service.validate_scopes(public_client, []) + assert set(scopes) == set(public_client.allowed_scopes) + + def test_validate_scopes_none_valid(self, public_client): + """Test validating with no valid scopes raises error.""" + with pytest.raises(service.InvalidScopeError): + service.validate_scopes(public_client, ["invalid1", "invalid2"]) + + +class TestAuthorizationCode: + """Tests for authorization code creation and exchange.""" + + @pytest.mark.asyncio + async def test_create_authorization_code_public_with_pkce( + self, db, public_client, test_user + ): + """Test creating authorization code for public client with PKCE.""" + code_verifier = secrets.token_urlsafe(64) + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + + code = await service.create_authorization_code( + db=db, + client=public_client, + user=test_user, + redirect_uri="http://localhost:3000/callback", + scope="openid profile", + code_challenge=code_challenge, + code_challenge_method="S256", + ) + + assert code is not None + assert len(code) > 64 + + @pytest.mark.asyncio + async def test_create_authorization_code_public_without_pkce_fails( + self, db, public_client, test_user + ): + """Test creating authorization code for public client without PKCE fails.""" + with pytest.raises(service.InvalidRequestError, match="PKCE"): + await service.create_authorization_code( + db=db, + client=public_client, + user=test_user, + redirect_uri="http://localhost:3000/callback", + scope="openid", + ) + + @pytest.mark.asyncio + async def test_exchange_authorization_code_success( + self, db, public_client, test_user + ): + """Test exchanging valid authorization code for tokens.""" + # Create PKCE challenge + code_verifier = secrets.token_urlsafe(64) + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + + # Create auth code + code = await service.create_authorization_code( + db=db, + client=public_client, + user=test_user, + redirect_uri="http://localhost:3000/callback", + scope="openid profile", + code_challenge=code_challenge, + code_challenge_method="S256", + ) + + # Exchange code + with patch("app.services.oauth_provider_service.settings") as mock_settings: + mock_settings.OAUTH_ISSUER = "http://localhost:8000" + mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456" + mock_settings.ALGORITHM = "HS256" + + result = await service.exchange_authorization_code( + db=db, + code=code, + client_id=public_client.client_id, + redirect_uri="http://localhost:3000/callback", + code_verifier=code_verifier, + ) + + assert "access_token" in result + assert "refresh_token" in result + assert result["token_type"] == "Bearer" + assert "expires_in" in result + + @pytest.mark.asyncio + async def test_exchange_authorization_code_invalid_code(self, db, public_client): + """Test exchanging invalid code fails.""" + with pytest.raises(service.InvalidGrantError, match="Invalid authorization"): + await service.exchange_authorization_code( + db=db, + code="invalid_code", + client_id=public_client.client_id, + redirect_uri="http://localhost:3000/callback", + ) + + @pytest.mark.asyncio + async def test_exchange_authorization_code_wrong_redirect_uri( + self, db, public_client, test_user + ): + """Test exchanging code with wrong redirect_uri fails.""" + code_verifier = secrets.token_urlsafe(64) + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + + code = await service.create_authorization_code( + db=db, + client=public_client, + user=test_user, + redirect_uri="http://localhost:3000/callback", + scope="openid", + code_challenge=code_challenge, + code_challenge_method="S256", + ) + + with pytest.raises(service.InvalidGrantError, match="redirect_uri mismatch"): + await service.exchange_authorization_code( + db=db, + code=code, + client_id=public_client.client_id, + redirect_uri="http://different.com/callback", + code_verifier=code_verifier, + ) + + @pytest.mark.asyncio + async def test_exchange_authorization_code_invalid_pkce( + self, db, public_client, test_user + ): + """Test exchanging code with invalid PKCE verifier fails.""" + code_verifier = secrets.token_urlsafe(64) + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + + code = await service.create_authorization_code( + db=db, + client=public_client, + user=test_user, + redirect_uri="http://localhost:3000/callback", + scope="openid", + code_challenge=code_challenge, + code_challenge_method="S256", + ) + + with pytest.raises(service.InvalidGrantError, match="Invalid code_verifier"): + await service.exchange_authorization_code( + db=db, + code=code, + client_id=public_client.client_id, + redirect_uri="http://localhost:3000/callback", + code_verifier="wrong_verifier", + ) + + +class TestTokenRefresh: + """Tests for token refresh.""" + + @pytest.mark.asyncio + async def test_refresh_tokens_success(self, db, public_client, test_user): + """Test refreshing tokens successfully.""" + # Create initial tokens + with patch("app.services.oauth_provider_service.settings") as mock_settings: + mock_settings.OAUTH_ISSUER = "http://localhost:8000" + mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456" + mock_settings.ALGORITHM = "HS256" + + result = await service.create_tokens( + db=db, + client=public_client, + user=test_user, + scope="openid profile", + ) + + refresh_token = result["refresh_token"] + + # Refresh the tokens + new_result = await service.refresh_tokens( + db=db, + refresh_token=refresh_token, + client_id=public_client.client_id, + ) + + assert "access_token" in new_result + assert "refresh_token" in new_result + assert new_result["refresh_token"] != refresh_token # Token rotation + + @pytest.mark.asyncio + async def test_refresh_tokens_invalid_token(self, db, public_client): + """Test refreshing with invalid token fails.""" + with pytest.raises(service.InvalidGrantError, match="Invalid refresh token"): + await service.refresh_tokens( + db=db, + refresh_token="invalid_token", + client_id=public_client.client_id, + ) + + @pytest.mark.asyncio + async def test_refresh_tokens_scope_reduction(self, db, public_client, test_user): + """Test refreshing with reduced scope.""" + with patch("app.services.oauth_provider_service.settings") as mock_settings: + mock_settings.OAUTH_ISSUER = "http://localhost:8000" + mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456" + mock_settings.ALGORITHM = "HS256" + + result = await service.create_tokens( + db=db, + client=public_client, + user=test_user, + scope="openid profile email", + ) + + new_result = await service.refresh_tokens( + db=db, + refresh_token=result["refresh_token"], + client_id=public_client.client_id, + scope="openid", # Reduced scope + ) + + assert "openid" in new_result["scope"] + assert "profile" not in new_result["scope"] + + @pytest.mark.asyncio + async def test_refresh_tokens_scope_expansion_fails( + self, db, public_client, test_user + ): + """Test refreshing with expanded scope fails.""" + with patch("app.services.oauth_provider_service.settings") as mock_settings: + mock_settings.OAUTH_ISSUER = "http://localhost:8000" + mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456" + mock_settings.ALGORITHM = "HS256" + + result = await service.create_tokens( + db=db, + client=public_client, + user=test_user, + scope="openid", + ) + + with pytest.raises(service.InvalidScopeError, match="Cannot expand scope"): + await service.refresh_tokens( + db=db, + refresh_token=result["refresh_token"], + client_id=public_client.client_id, + scope="openid profile", # Expanded scope + ) + + +class TestTokenRevocation: + """Tests for token revocation.""" + + @pytest.mark.asyncio + async def test_revoke_refresh_token(self, db, public_client, test_user): + """Test revoking a refresh token.""" + with patch("app.services.oauth_provider_service.settings") as mock_settings: + mock_settings.OAUTH_ISSUER = "http://localhost:8000" + mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456" + mock_settings.ALGORITHM = "HS256" + + result = await service.create_tokens( + db=db, + client=public_client, + user=test_user, + scope="openid", + ) + + # Revoke the token + revoked = await service.revoke_token( + db=db, + token=result["refresh_token"], + token_type_hint="refresh_token", + ) + + assert revoked is True + + # Try to use revoked token + with pytest.raises(service.InvalidGrantError, match="revoked"): + await service.refresh_tokens( + db=db, + refresh_token=result["refresh_token"], + client_id=public_client.client_id, + ) + + @pytest.mark.asyncio + async def test_revoke_all_user_tokens(self, db, public_client, test_user): + """Test revoking all tokens for a user.""" + with patch("app.services.oauth_provider_service.settings") as mock_settings: + mock_settings.OAUTH_ISSUER = "http://localhost:8000" + mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456" + mock_settings.ALGORITHM = "HS256" + + # Create multiple tokens (we don't need to capture results) + await service.create_tokens( + db=db, + client=public_client, + user=test_user, + scope="openid", + ) + await service.create_tokens( + db=db, + client=public_client, + user=test_user, + scope="profile", + ) + + # Revoke all + count = await service.revoke_all_user_tokens(db, test_user.id) + assert count == 2 + + +class TestTokenIntrospection: + """Tests for token introspection (RFC 7662).""" + + @pytest.mark.asyncio + async def test_introspect_valid_access_token(self, db, public_client, test_user): + """Test introspecting a valid access token.""" + with patch("app.services.oauth_provider_service.settings") as mock_settings: + mock_settings.OAUTH_ISSUER = "http://localhost:8000" + mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456" + mock_settings.ALGORITHM = "HS256" + + result = await service.create_tokens( + db=db, + client=public_client, + user=test_user, + scope="openid profile", + ) + + introspection = await service.introspect_token( + db=db, + token=result["access_token"], + ) + + assert introspection["active"] is True + assert introspection["client_id"] == public_client.client_id + assert introspection["sub"] == str(test_user.id) + + @pytest.mark.asyncio + async def test_introspect_invalid_token(self, db): + """Test introspecting an invalid token.""" + introspection = await service.introspect_token( + db=db, + token="invalid_token", + ) + assert introspection["active"] is False + + +class TestConsentManagement: + """Tests for consent management.""" + + @pytest.mark.asyncio + async def test_grant_consent(self, db, public_client, test_user): + """Test granting consent.""" + consent = await service.grant_consent( + db=db, + user_id=test_user.id, + client_id=public_client.client_id, + scopes=["openid", "profile"], + ) + + assert consent is not None + assert "openid" in consent.granted_scopes + assert "profile" in consent.granted_scopes + + @pytest.mark.asyncio + async def test_check_consent_granted(self, db, public_client, test_user): + """Test checking granted consent.""" + await service.grant_consent( + db=db, + user_id=test_user.id, + client_id=public_client.client_id, + scopes=["openid", "profile"], + ) + + has_consent = await service.check_consent( + db=db, + user_id=test_user.id, + client_id=public_client.client_id, + requested_scopes=["openid"], + ) + assert has_consent is True + + @pytest.mark.asyncio + async def test_check_consent_not_granted(self, db, public_client, test_user): + """Test checking consent that hasn't been granted.""" + has_consent = await service.check_consent( + db=db, + user_id=test_user.id, + client_id=public_client.client_id, + requested_scopes=["openid"], + ) + assert has_consent is False + + @pytest.mark.asyncio + async def test_revoke_consent(self, db, public_client, test_user): + """Test revoking consent.""" + await service.grant_consent( + db=db, + user_id=test_user.id, + client_id=public_client.client_id, + scopes=["openid"], + ) + + revoked = await service.revoke_consent( + db=db, + user_id=test_user.id, + client_id=public_client.client_id, + ) + assert revoked is True + + # Check consent is gone + has_consent = await service.check_consent( + db=db, + user_id=test_user.id, + client_id=public_client.client_id, + requested_scopes=["openid"], + ) + assert has_consent is False + + +class TestOAuthErrors: + """Tests for OAuth error classes.""" + + def test_invalid_client_error(self): + """Test InvalidClientError.""" + error = service.InvalidClientError("Test description") + assert error.error == "invalid_client" + assert error.error_description == "Test description" + + def test_invalid_grant_error(self): + """Test InvalidGrantError.""" + error = service.InvalidGrantError("Test description") + assert error.error == "invalid_grant" + assert error.error_description == "Test description" + + def test_invalid_request_error(self): + """Test InvalidRequestError.""" + error = service.InvalidRequestError("Test description") + assert error.error == "invalid_request" + assert error.error_description == "Test description" + + def test_invalid_scope_error(self): + """Test InvalidScopeError.""" + error = service.InvalidScopeError("Test description") + assert error.error == "invalid_scope" + assert error.error_description == "Test description" + + def test_access_denied_error(self): + """Test AccessDeniedError.""" + error = service.AccessDeniedError("Test description") + assert error.error == "access_denied" + assert error.error_description == "Test description" diff --git a/frontend/src/app/[locale]/(auth)/auth/consent/page.tsx b/frontend/src/app/[locale]/(auth)/auth/consent/page.tsx new file mode 100644 index 0000000..c6349c9 --- /dev/null +++ b/frontend/src/app/[locale]/(auth)/auth/consent/page.tsx @@ -0,0 +1,325 @@ +/** + * OAuth Consent Page + * Displays authorization consent form for OAuth provider mode (MCP integration). + * + * Users are redirected here when an external application (MCP client) requests + * access to their account. They can approve or deny the requested permissions. + */ + +'use client'; + +import { useState, useEffect } from 'react'; +import { useSearchParams } from 'next/navigation'; +import { useRouter } from '@/lib/i18n/routing'; +import { useTranslations } from 'next-intl'; +import { Button } from '@/components/ui/button'; +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from '@/components/ui/card'; +import { Alert, AlertDescription } from '@/components/ui/alert'; +import { Checkbox } from '@/components/ui/checkbox'; +import { Label } from '@/components/ui/label'; +import { Loader2, Shield, AlertTriangle, ExternalLink, CheckCircle2 } from 'lucide-react'; +import { useAuth } from '@/lib/auth/AuthContext'; +import config from '@/config/app.config'; + +// Scope descriptions for display +const SCOPE_INFO: Record = { + openid: { + name: 'OpenID Connect', + description: 'Verify your identity', + icon: 'user', + }, + profile: { + name: 'Profile', + description: 'Access your name and basic profile information', + icon: 'user-circle', + }, + email: { + name: 'Email', + description: 'Access your email address', + icon: 'mail', + }, + 'read:users': { + name: 'Read Users', + description: 'View user information', + icon: 'users', + }, + 'write:users': { + name: 'Write Users', + description: 'Modify user information', + icon: 'user-edit', + }, + 'read:organizations': { + name: 'Read Organizations', + description: 'View organization information', + icon: 'building', + }, + 'write:organizations': { + name: 'Write Organizations', + description: 'Modify organization information', + icon: 'building-edit', + }, + admin: { + name: 'Admin Access', + description: 'Full administrative access', + icon: 'shield', + }, +}; + +interface ConsentParams { + clientId: string; + clientName: string; + redirectUri: string; + scope: string; + state: string; + codeChallenge: string; + codeChallengeMethod: string; + nonce: string; +} + +export default function OAuthConsentPage() { + const searchParams = useSearchParams(); + const router = useRouter(); + // Note: t is available for future i18n use + const _t = useTranslations('auth.oauth'); + void _t; // Suppress unused warning - ready for i18n + const { isAuthenticated, isLoading: authLoading } = useAuth(); + + const [isSubmitting, setIsSubmitting] = useState(false); + const [error, setError] = useState(null); + const [selectedScopes, setSelectedScopes] = useState>(new Set()); + const [params, setParams] = useState(null); + + // Parse URL parameters + useEffect(() => { + const clientId = searchParams.get('client_id') || ''; + const clientName = searchParams.get('client_name') || 'Application'; + const redirectUri = searchParams.get('redirect_uri') || ''; + const scope = searchParams.get('scope') || ''; + const state = searchParams.get('state') || ''; + const codeChallenge = searchParams.get('code_challenge') || ''; + const codeChallengeMethod = searchParams.get('code_challenge_method') || ''; + const nonce = searchParams.get('nonce') || ''; + + if (!clientId || !redirectUri) { + setError('Invalid authorization request. Missing required parameters.'); + return; + } + + setParams({ + clientId, + clientName, + redirectUri, + scope, + state, + codeChallenge, + codeChallengeMethod, + nonce, + }); + + // Initialize selected scopes with all requested scopes + if (scope) { + setSelectedScopes(new Set(scope.split(' '))); + } + }, [searchParams]); + + // Redirect to login if not authenticated + useEffect(() => { + if (!authLoading && !isAuthenticated) { + const returnUrl = `/auth/consent?${searchParams.toString()}`; + router.push(`${config.routes.login}?return_to=${encodeURIComponent(returnUrl)}`); + } + }, [authLoading, isAuthenticated, router, searchParams]); + + const handleScopeToggle = (scope: string) => { + setSelectedScopes((prev) => { + const next = new Set(prev); + if (next.has(scope)) { + next.delete(scope); + } else { + next.add(scope); + } + return next; + }); + }; + + const handleSubmit = async (approved: boolean) => { + if (!params) return; + + setIsSubmitting(true); + setError(null); + + try { + // Create form data for consent submission + const formData = new FormData(); + formData.append('approved', approved.toString()); + formData.append('client_id', params.clientId); + formData.append('redirect_uri', params.redirectUri); + formData.append('scope', Array.from(selectedScopes).join(' ')); + formData.append('state', params.state); + if (params.codeChallenge) { + formData.append('code_challenge', params.codeChallenge); + } + if (params.codeChallengeMethod) { + formData.append('code_challenge_method', params.codeChallengeMethod); + } + if (params.nonce) { + formData.append('nonce', params.nonce); + } + + // Submit consent to backend + const apiUrl = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000'; + const response = await fetch(`${apiUrl}/api/v1/oauth/provider/authorize/consent`, { + method: 'POST', + body: formData, + credentials: 'include', + }); + + // The endpoint returns a redirect, so follow it + if (response.redirected) { + window.location.href = response.url; + } else if (!response.ok) { + const data = await response.json(); + throw new Error(data.detail || 'Failed to process consent'); + } + } catch (err) { + setError(err instanceof Error ? err.message : 'An unexpected error occurred'); + setIsSubmitting(false); + } + }; + + // Show loading state while checking auth + if (authLoading) { + return ( +
+
+ +

Loading...

+
+
+ ); + } + + // Show error state + if (error && !params) { + return ( +
+
+ + + {error} + +
+ +
+
+
+ ); + } + + if (!params) { + return null; + } + + const requestedScopes = params.scope ? params.scope.split(' ') : []; + + return ( +
+ + +
+ +
+ Authorization Request + + {params.clientName} wants to + access your account + +
+ + + {error && ( + + + {error} + + )} + +
+

This application will be able to:

+
+ {requestedScopes.map((scope) => { + const info = SCOPE_INFO[scope] || { + name: scope, + description: `Access to ${scope}`, + }; + const isSelected = selectedScopes.has(scope); + + return ( +
+ handleScopeToggle(scope)} + disabled={isSubmitting} + /> +
+ +

{info.description}

+
+ {isSelected && } +
+ ); + })} +
+
+ + + + + After authorization, you will be redirected to: +
+ + {params.redirectUri} + +
+
+
+ + + + + +
+
+ ); +}