Compare commits
6 Commits
d49f819469
...
dc875c5c95
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc875c5c95 | ||
|
|
0ea428b718 | ||
|
|
400d6f6f75 | ||
|
|
7716468d72 | ||
|
|
48f052200f | ||
|
|
fbb030da69 |
28
AGENTS.md
28
AGENTS.md
@@ -42,7 +42,7 @@ Default superuser (change in production):
|
||||
│ │ ├── schemas/ # Pydantic request/response schemas
|
||||
│ │ ├── services/ # Business logic layer
|
||||
│ │ └── utils/ # Utilities (security, device detection)
|
||||
│ ├── tests/ # 97% coverage, 743 tests
|
||||
│ ├── tests/ # 96% coverage, 987 tests
|
||||
│ └── alembic/ # Database migrations
|
||||
│
|
||||
└── frontend/ # Next.js 15 frontend
|
||||
@@ -69,6 +69,27 @@ Default superuser (change in production):
|
||||
- `get_optional_current_user`: Accepts authenticated or anonymous
|
||||
- `get_current_superuser`: Requires superuser flag
|
||||
|
||||
### OAuth Provider Mode (MCP Integration)
|
||||
Full OAuth 2.0 Authorization Server for MCP (Model Context Protocol) clients:
|
||||
- **Authorization Code Flow with PKCE**: RFC 7636 compliant
|
||||
- **JWT access tokens**: Self-contained, no DB lookup required
|
||||
- **Opaque refresh tokens**: Stored hashed in database, supports rotation
|
||||
- **Token introspection**: RFC 7662 compliant endpoint
|
||||
- **Token revocation**: RFC 7009 compliant endpoint
|
||||
- **Server metadata**: RFC 8414 compliant discovery endpoint
|
||||
- **Consent management**: User can review and revoke app permissions
|
||||
|
||||
**API endpoints:**
|
||||
- `GET /.well-known/oauth-authorization-server` - Server metadata
|
||||
- `GET /oauth/provider/authorize` - Authorization endpoint
|
||||
- `POST /oauth/provider/authorize/consent` - Consent submission
|
||||
- `POST /oauth/provider/token` - Token endpoint
|
||||
- `POST /oauth/provider/revoke` - Token revocation
|
||||
- `POST /oauth/provider/introspect` - Token introspection
|
||||
- Client management endpoints (admin only)
|
||||
|
||||
**Scopes supported:** `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
||||
|
||||
### Database Pattern
|
||||
- **Async SQLAlchemy 2.0** with PostgreSQL
|
||||
- **Connection pooling**: 20 base connections, 50 max overflow
|
||||
@@ -107,7 +128,7 @@ Permission dependencies in `api/dependencies/permissions.py`:
|
||||
### Testing Infrastructure
|
||||
|
||||
**Backend Unit/Integration (pytest + SQLite):**
|
||||
- 97% coverage, 743+ tests
|
||||
- 96% coverage, 987 tests
|
||||
- Security-focused: JWT attacks, session hijacking, privilege escalation
|
||||
- Async fixtures in `tests/conftest.py`
|
||||
- Run: `IS_TEST=True uv run pytest` or `make test`
|
||||
@@ -238,12 +259,13 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
|
||||
|
||||
### Completed Features ✅
|
||||
- Authentication system (JWT with refresh tokens, OAuth/social login)
|
||||
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
|
||||
- Session management (device tracking, revocation)
|
||||
- User management (CRUD, password change)
|
||||
- Organization system (multi-tenant with RBAC)
|
||||
- Admin panel (user/org management, bulk operations)
|
||||
- **Internationalization (i18n)** with English and Italian
|
||||
- Comprehensive test coverage (97% backend, 97% frontend unit, 56 E2E tests)
|
||||
- Comprehensive test coverage (96% backend, 97% frontend unit, 56 E2E tests)
|
||||
- Design system documentation
|
||||
- **Marketing landing page** with animations
|
||||
- **`/dev` documentation portal** with live examples
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
"""Add OAuth provider models for MCP integration.
|
||||
|
||||
Revision ID: f8c3d2e1a4b5
|
||||
Revises: d5a7b2c9e1f3
|
||||
Create Date: 2025-01-15 10:00:00.000000
|
||||
|
||||
This migration adds tables for OAuth provider mode:
|
||||
- oauth_authorization_codes: Temporary authorization codes
|
||||
- oauth_provider_refresh_tokens: Long-lived refresh tokens
|
||||
- oauth_consents: User consent records
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f8c3d2e1a4b5"
|
||||
down_revision = "d5a7b2c9e1f3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create oauth_authorization_codes table
|
||||
op.create_table(
|
||||
"oauth_authorization_codes",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("code", sa.String(128), nullable=False),
|
||||
sa.Column("client_id", sa.String(64), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(2048), nullable=False),
|
||||
sa.Column("scope", sa.String(1000), nullable=False, server_default=""),
|
||||
sa.Column("code_challenge", sa.String(128), nullable=True),
|
||||
sa.Column("code_challenge_method", sa.String(10), nullable=True),
|
||||
sa.Column("state", sa.String(256), nullable=True),
|
||||
sa.Column("nonce", sa.String(256), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("used", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"],
|
||||
["oauth_clients.client_id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_code",
|
||||
"oauth_authorization_codes",
|
||||
["code"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_expires_at",
|
||||
"oauth_authorization_codes",
|
||||
["expires_at"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_client_user",
|
||||
"oauth_authorization_codes",
|
||||
["client_id", "user_id"],
|
||||
)
|
||||
|
||||
# Create oauth_provider_refresh_tokens table
|
||||
op.create_table(
|
||||
"oauth_provider_refresh_tokens",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("token_hash", sa.String(64), nullable=False),
|
||||
sa.Column("jti", sa.String(64), nullable=False),
|
||||
sa.Column("client_id", sa.String(64), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("scope", sa.String(1000), nullable=False, server_default=""),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("revoked", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("device_info", sa.String(500), nullable=True),
|
||||
sa.Column("ip_address", sa.String(45), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"],
|
||||
["oauth_clients.client_id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_token_hash",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["token_hash"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_jti",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["expires_at"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_client_user",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["client_id", "user_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["user_id", "revoked"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_revoked",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["revoked"],
|
||||
)
|
||||
|
||||
# Create oauth_consents table
|
||||
op.create_table(
|
||||
"oauth_consents",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("client_id", sa.String(64), nullable=False),
|
||||
sa.Column("granted_scopes", sa.String(1000), nullable=False, server_default=""),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"],
|
||||
["oauth_clients.client_id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_consents_user_client",
|
||||
"oauth_consents",
|
||||
["user_id", "client_id"],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("oauth_consents")
|
||||
op.drop_table("oauth_provider_refresh_tokens")
|
||||
op.drop_table("oauth_authorization_codes")
|
||||
@@ -1,37 +1,63 @@
|
||||
# app/api/routes/oauth_provider.py
|
||||
"""
|
||||
OAuth Provider routes (Authorization Server mode).
|
||||
OAuth Provider routes (Authorization Server mode) for MCP integration.
|
||||
|
||||
This is a skeleton implementation for MCP (Model Context Protocol) client authentication.
|
||||
Provides basic OAuth 2.0 endpoints that can be expanded for full functionality.
|
||||
|
||||
Endpoints:
|
||||
Implements OAuth 2.0 Authorization Server endpoints:
|
||||
- GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414)
|
||||
- GET /oauth/provider/authorize - Authorization endpoint (skeleton)
|
||||
- POST /oauth/provider/token - Token endpoint (skeleton)
|
||||
- POST /oauth/provider/revoke - Token revocation endpoint (skeleton)
|
||||
- GET /oauth/provider/authorize - Authorization endpoint
|
||||
- POST /oauth/provider/token - Token endpoint
|
||||
- POST /oauth/provider/revoke - Token revocation (RFC 7009)
|
||||
- POST /oauth/provider/introspect - Token introspection (RFC 7662)
|
||||
- Client management endpoints
|
||||
|
||||
NOTE: This is intentionally minimal. Full implementation should include:
|
||||
- Complete authorization code flow
|
||||
- Refresh token handling
|
||||
- Scope validation
|
||||
- Client authentication
|
||||
- PKCE support
|
||||
Security features:
|
||||
- PKCE required for public clients (S256)
|
||||
- CSRF protection via state parameter
|
||||
- Secure token handling
|
||||
- Rate limiting on sensitive endpoints
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_active_user, get_current_superuser
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.crud import oauth_client
|
||||
from app.schemas.oauth import OAuthServerMetadata
|
||||
from app.crud import oauth_client as oauth_client_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import (
|
||||
OAuthClientCreate,
|
||||
OAuthClientResponse,
|
||||
OAuthServerMetadata,
|
||||
OAuthTokenIntrospectionResponse,
|
||||
OAuthTokenResponse,
|
||||
)
|
||||
from app.services import oauth_provider_service as provider_service
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
def require_provider_enabled():
|
||||
"""Dependency to check if OAuth provider mode is enabled."""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled. Set OAUTH_PROVIDER_ENABLED=true",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Server Metadata (RFC 8414)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -42,24 +68,15 @@ logger = logging.getLogger(__name__)
|
||||
OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
||||
|
||||
Returns server metadata including supported endpoints, scopes,
|
||||
and capabilities for MCP clients.
|
||||
and capabilities. MCP clients use this to discover the server.
|
||||
""",
|
||||
operation_id="get_oauth_server_metadata",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def get_server_metadata() -> Any:
|
||||
"""
|
||||
Get OAuth 2.0 server metadata.
|
||||
|
||||
This endpoint is used by MCP clients to discover the authorization
|
||||
server's capabilities.
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
)
|
||||
|
||||
async def get_server_metadata(
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> OAuthServerMetadata:
|
||||
"""Get OAuth 2.0 server metadata."""
|
||||
base_url = settings.OAUTH_ISSUER.rstrip("/")
|
||||
|
||||
return OAuthServerMetadata(
|
||||
@@ -67,7 +84,8 @@ async def get_server_metadata() -> Any:
|
||||
authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize",
|
||||
token_endpoint=f"{base_url}/api/v1/oauth/provider/token",
|
||||
revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke",
|
||||
registration_endpoint=None, # Dynamic registration not implemented
|
||||
introspection_endpoint=f"{base_url}/api/v1/oauth/provider/introspect",
|
||||
registration_endpoint=None, # Dynamic registration not supported
|
||||
scopes_supported=[
|
||||
"openid",
|
||||
"profile",
|
||||
@@ -76,148 +94,446 @@ async def get_server_metadata() -> Any:
|
||||
"write:users",
|
||||
"read:organizations",
|
||||
"write:organizations",
|
||||
"admin",
|
||||
],
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code", "refresh_token"],
|
||||
code_challenge_methods_supported=["S256"],
|
||||
token_endpoint_auth_methods_supported=[
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none", # For public clients with PKCE
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/authorize",
|
||||
summary="Authorization Endpoint (Skeleton)",
|
||||
summary="Authorization Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Authorization Endpoint.
|
||||
|
||||
**NOTE**: This is a skeleton implementation. In a full implementation,
|
||||
this would:
|
||||
1. Validate client_id and redirect_uri
|
||||
2. Display consent screen to user
|
||||
3. Generate authorization code
|
||||
4. Redirect back to client with code
|
||||
Initiates the authorization code flow:
|
||||
1. Validates client and parameters
|
||||
2. Checks if user is authenticated (redirects to login if not)
|
||||
3. Checks existing consent
|
||||
4. Redirects to consent page if needed
|
||||
5. Issues authorization code and redirects back to client
|
||||
|
||||
Currently returns a 501 Not Implemented response.
|
||||
Required parameters:
|
||||
- response_type: Must be "code"
|
||||
- client_id: Registered client ID
|
||||
- redirect_uri: Must match registered URI
|
||||
|
||||
Recommended parameters:
|
||||
- state: CSRF protection
|
||||
- code_challenge + code_challenge_method: PKCE (required for public clients)
|
||||
- scope: Requested permissions
|
||||
""",
|
||||
operation_id="oauth_provider_authorize",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def authorize(
|
||||
request: Request,
|
||||
response_type: str = Query(..., description="Must be 'code'"),
|
||||
client_id: str = Query(..., description="OAuth client ID"),
|
||||
redirect_uri: str = Query(..., description="Redirect URI"),
|
||||
scope: str = Query(default="", description="Requested scopes"),
|
||||
scope: str = Query(default="", description="Requested scopes (space-separated)"),
|
||||
state: str = Query(default="", description="CSRF state parameter"),
|
||||
code_challenge: str | None = Query(default=None, description="PKCE code challenge"),
|
||||
code_challenge_method: str | None = Query(
|
||||
default=None, description="PKCE method (S256)"
|
||||
),
|
||||
nonce: str | None = Query(default=None, description="OpenID Connect nonce"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User | None = Depends(get_current_active_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Authorization endpoint (skeleton).
|
||||
Authorization endpoint - initiates OAuth flow.
|
||||
|
||||
In a full implementation, this would:
|
||||
1. Validate the client and redirect URI
|
||||
2. Authenticate the user (if not already)
|
||||
3. Show consent screen
|
||||
4. Generate authorization code
|
||||
5. Redirect to redirect_uri with code
|
||||
If user is not authenticated, redirects to login with return URL.
|
||||
If user has not consented, redirects to consent page.
|
||||
If all checks pass, generates code and redirects to client.
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
)
|
||||
|
||||
# Validate client exists
|
||||
client = await oauth_client.get_by_client_id(db, client_id=client_id)
|
||||
if not client:
|
||||
# Validate response_type
|
||||
if response_type != "code":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="invalid_client: Unknown client_id",
|
||||
detail="invalid_request: response_type must be 'code'",
|
||||
)
|
||||
|
||||
# Validate redirect_uri
|
||||
if redirect_uri not in (client.redirect_uris or []):
|
||||
# Validate PKCE method if provided - ONLY S256 is allowed (RFC 7636 Section 4.3)
|
||||
# "plain" method provides no security benefit and MUST NOT be used
|
||||
if code_challenge_method and code_challenge_method != "S256":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="invalid_request: Invalid redirect_uri",
|
||||
detail="invalid_request: code_challenge_method must be 'S256' (plain is not supported)",
|
||||
)
|
||||
|
||||
# Skeleton: Return not implemented
|
||||
# Full implementation would redirect to consent screen
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Authorization endpoint not fully implemented. "
|
||||
"This is a skeleton for MCP integration.",
|
||||
# Validate client
|
||||
try:
|
||||
client = await provider_service.get_client(db, client_id)
|
||||
if not client:
|
||||
raise provider_service.InvalidClientError("Unknown client_id")
|
||||
provider_service.validate_redirect_uri(client, redirect_uri)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
# For client/redirect errors, we can't safely redirect - show error
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
)
|
||||
|
||||
# Validate and filter scopes
|
||||
try:
|
||||
requested_scopes = provider_service.parse_scope(scope)
|
||||
valid_scopes = provider_service.validate_scopes(client, requested_scopes)
|
||||
except provider_service.InvalidScopeError as e:
|
||||
# Redirect with error
|
||||
scope_error_params: dict[str, str] = {"error": e.error}
|
||||
if e.error_description:
|
||||
scope_error_params["error_description"] = e.error_description
|
||||
if state:
|
||||
scope_error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(scope_error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Public clients MUST use PKCE
|
||||
if client.client_type == "public":
|
||||
if not code_challenge or code_challenge_method != "S256":
|
||||
pkce_error_params: dict[str, str] = {
|
||||
"error": "invalid_request",
|
||||
"error_description": "PKCE with S256 is required for public clients",
|
||||
}
|
||||
if state:
|
||||
pkce_error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(pkce_error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# If user is not authenticated, redirect to login
|
||||
if not current_user:
|
||||
# Store authorization request in session and redirect to login
|
||||
# The frontend will handle the return URL
|
||||
login_url = f"{settings.FRONTEND_URL}/login"
|
||||
return_params = urlencode(
|
||||
{
|
||||
"oauth_authorize": "true",
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": " ".join(valid_scopes),
|
||||
"state": state,
|
||||
"code_challenge": code_challenge or "",
|
||||
"code_challenge_method": code_challenge_method or "",
|
||||
"nonce": nonce or "",
|
||||
}
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{login_url}?return_to=/auth/consent?{return_params}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Check if user has already consented
|
||||
has_consent = await provider_service.check_consent(
|
||||
db, current_user.id, client_id, valid_scopes
|
||||
)
|
||||
|
||||
if not has_consent:
|
||||
# Redirect to consent page
|
||||
consent_params = urlencode(
|
||||
{
|
||||
"client_id": client_id,
|
||||
"client_name": client.client_name,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": " ".join(valid_scopes),
|
||||
"state": state,
|
||||
"code_challenge": code_challenge or "",
|
||||
"code_challenge_method": code_challenge_method or "",
|
||||
"nonce": nonce or "",
|
||||
}
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# User is authenticated and has consented - issue authorization code
|
||||
try:
|
||||
code = await provider_service.create_authorization_code(
|
||||
db=db,
|
||||
client=client,
|
||||
user=current_user,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=" ".join(valid_scopes),
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
error_params: dict[str, str] = {"error": e.error}
|
||||
if e.error_description:
|
||||
error_params["error_description"] = e.error_description
|
||||
if state:
|
||||
error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Success - redirect with code
|
||||
success_params = {"code": code}
|
||||
if state:
|
||||
success_params["state"] = state
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(success_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/authorize/consent",
|
||||
summary="Submit Authorization Consent",
|
||||
description="""
|
||||
Submit user consent for OAuth authorization.
|
||||
|
||||
Called by the consent page after user approves or denies.
|
||||
""",
|
||||
operation_id="oauth_provider_consent",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def submit_consent(
|
||||
request: Request,
|
||||
approved: bool = Form(..., description="Whether user approved"),
|
||||
client_id: str = Form(..., description="OAuth client ID"),
|
||||
redirect_uri: str = Form(..., description="Redirect URI"),
|
||||
scope: str = Form(default="", description="Granted scopes"),
|
||||
state: str = Form(default="", description="CSRF state parameter"),
|
||||
code_challenge: str | None = Form(default=None),
|
||||
code_challenge_method: str | None = Form(default=None),
|
||||
nonce: str | None = Form(default=None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> Any:
|
||||
"""Process consent form submission."""
|
||||
# Validate client
|
||||
try:
|
||||
client = await provider_service.get_client(db, client_id)
|
||||
if not client:
|
||||
raise provider_service.InvalidClientError("Unknown client_id")
|
||||
provider_service.validate_redirect_uri(client, redirect_uri)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
)
|
||||
|
||||
# If user denied, redirect with error
|
||||
if not approved:
|
||||
denied_params: dict[str, str] = {
|
||||
"error": "access_denied",
|
||||
"error_description": "User denied authorization",
|
||||
}
|
||||
if state:
|
||||
denied_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(denied_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
granted_scopes = provider_service.parse_scope(scope)
|
||||
valid_scopes = provider_service.validate_scopes(client, granted_scopes)
|
||||
|
||||
# Record consent
|
||||
await provider_service.grant_consent(db, current_user.id, client_id, valid_scopes)
|
||||
|
||||
# Generate authorization code
|
||||
try:
|
||||
code = await provider_service.create_authorization_code(
|
||||
db=db,
|
||||
client=client,
|
||||
user=current_user,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=" ".join(valid_scopes),
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
error_params: dict[str, str] = {"error": e.error}
|
||||
if e.error_description:
|
||||
error_params["error_description"] = e.error_description
|
||||
if state:
|
||||
error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Success
|
||||
success_params = {"code": code}
|
||||
if state:
|
||||
success_params["state"] = state
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(success_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/token",
|
||||
summary="Token Endpoint (Skeleton)",
|
||||
response_model=OAuthTokenResponse,
|
||||
summary="Token Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Token Endpoint.
|
||||
|
||||
**NOTE**: This is a skeleton implementation. In a full implementation,
|
||||
this would exchange authorization codes for access tokens.
|
||||
Supports:
|
||||
- authorization_code: Exchange code for tokens
|
||||
- refresh_token: Refresh access token
|
||||
|
||||
Currently returns a 501 Not Implemented response.
|
||||
Client authentication:
|
||||
- Confidential clients: client_secret (Basic auth or POST body)
|
||||
- Public clients: No secret, but PKCE code_verifier required
|
||||
""",
|
||||
operation_id="oauth_provider_token",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("60/minute")
|
||||
async def token(
|
||||
grant_type: str = Form(..., description="Grant type (authorization_code)"),
|
||||
request: Request,
|
||||
grant_type: str = Form(..., description="Grant type"),
|
||||
code: str | None = Form(default=None, description="Authorization code"),
|
||||
redirect_uri: str | None = Form(default=None, description="Redirect URI"),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
code_verifier: str | None = Form(default=None, description="PKCE code verifier"),
|
||||
refresh_token: str | None = Form(default=None, description="Refresh token"),
|
||||
scope: str | None = Form(default=None, description="Scope (for refresh)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Token endpoint (skeleton).
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> OAuthTokenResponse:
|
||||
"""Token endpoint - exchange code for tokens or refresh."""
|
||||
# Extract client credentials from Basic auth if not in body
|
||||
if not client_id:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
|
||||
Supported grant types (when fully implemented):
|
||||
- authorization_code: Exchange code for tokens
|
||||
- refresh_token: Refresh access token
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
f"Malformed Basic auth header in token request: {type(e).__name__}"
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid_client: client_id required",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
|
||||
if grant_type not in ["authorization_code", "refresh_token"]:
|
||||
# Get device info
|
||||
device_info = request.headers.get("User-Agent", "")[:500]
|
||||
ip_address = get_remote_address(request)
|
||||
|
||||
try:
|
||||
if grant_type == "authorization_code":
|
||||
if not code:
|
||||
raise provider_service.InvalidRequestError("code required")
|
||||
if not redirect_uri:
|
||||
raise provider_service.InvalidRequestError("redirect_uri required")
|
||||
|
||||
result = await provider_service.exchange_authorization_code(
|
||||
db=db,
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=code_verifier,
|
||||
client_secret=client_secret,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
elif grant_type == "refresh_token":
|
||||
if not refresh_token:
|
||||
raise provider_service.InvalidRequestError("refresh_token required")
|
||||
|
||||
result = await provider_service.refresh_tokens(
|
||||
db=db,
|
||||
refresh_token=refresh_token,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scope=scope,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="unsupported_grant_type: Must be authorization_code or refresh_token",
|
||||
)
|
||||
|
||||
return OAuthTokenResponse(**result)
|
||||
|
||||
except provider_service.InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="unsupported_grant_type",
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
)
|
||||
|
||||
# Skeleton: Return not implemented
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Token endpoint not fully implemented. "
|
||||
"This is a skeleton for MCP integration.",
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation (RFC 7009)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/revoke",
|
||||
summary="Token Revocation Endpoint (Skeleton)",
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Token Revocation Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
|
||||
|
||||
**NOTE**: This is a skeleton implementation.
|
||||
|
||||
Currently returns a 501 Not Implemented response.
|
||||
Revokes an access token or refresh token.
|
||||
Always returns 200 OK (even if token is invalid) per spec.
|
||||
""",
|
||||
operation_id="oauth_provider_revoke",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def revoke(
|
||||
request: Request,
|
||||
token: str = Form(..., description="Token to revoke"),
|
||||
token_type_hint: str | None = Form(
|
||||
default=None, description="Token type hint (access_token, refresh_token)"
|
||||
@@ -225,88 +541,298 @@ async def revoke(
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Token revocation endpoint (skeleton).
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> dict[str, str]:
|
||||
"""Revoke a token."""
|
||||
# Extract client credentials from Basic auth if not in body
|
||||
if not client_id:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
|
||||
In a full implementation, this would invalidate the specified token.
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
f"Malformed Basic auth header in revoke request: {type(e).__name__}"
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
try:
|
||||
await provider_service.revoke_token(
|
||||
db=db,
|
||||
token=token,
|
||||
token_type_hint=token_type_hint,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
except provider_service.InvalidClientError:
|
||||
# Per RFC 7009, we should return 200 OK even for errors
|
||||
# But client authentication errors can return 401
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid_client",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't expose errors per RFC 7009
|
||||
logger.warning(f"Token revocation error: {e}")
|
||||
|
||||
# Skeleton: Return not implemented
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Revocation endpoint not fully implemented. "
|
||||
"This is a skeleton for MCP integration.",
|
||||
)
|
||||
# Always return 200 OK per RFC 7009
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client Management (Admin only)
|
||||
# Token Introspection (RFC 7662)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/introspect",
|
||||
response_model=OAuthTokenIntrospectionResponse,
|
||||
summary="Token Introspection Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Token Introspection Endpoint (RFC 7662).
|
||||
|
||||
Allows resource servers to query the authorization server
|
||||
to determine the active state and metadata of a token.
|
||||
""",
|
||||
operation_id="oauth_provider_introspect",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("120/minute")
|
||||
async def introspect(
|
||||
request: Request,
|
||||
token: str = Form(..., description="Token to introspect"),
|
||||
token_type_hint: str | None = Form(
|
||||
default=None, description="Token type hint (access_token, refresh_token)"
|
||||
),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> OAuthTokenIntrospectionResponse:
|
||||
"""Introspect a token."""
|
||||
# Extract client credentials from Basic auth if not in body
|
||||
if not client_id:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
f"Malformed Basic auth header in introspect request: {type(e).__name__}"
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
try:
|
||||
result = await provider_service.introspect_token(
|
||||
db=db,
|
||||
token=token,
|
||||
token_type_hint=token_type_hint,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
return OAuthTokenIntrospectionResponse(**result)
|
||||
except provider_service.InvalidClientError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid_client",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Token introspection error: {e}")
|
||||
return OAuthTokenIntrospectionResponse(active=False)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client Management (Admin)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/clients",
|
||||
summary="Register OAuth Client (Admin)",
|
||||
response_model=dict,
|
||||
summary="Register OAuth Client",
|
||||
description="""
|
||||
Register a new OAuth client (admin only).
|
||||
|
||||
This endpoint allows creating MCP clients that can authenticate
|
||||
against this API.
|
||||
Creates an MCP client that can authenticate against this API.
|
||||
Returns client_id and client_secret (for confidential clients).
|
||||
|
||||
**NOTE**: This is a minimal implementation.
|
||||
**Important:** Store the client_secret securely - it won't be shown again!
|
||||
""",
|
||||
operation_id="register_oauth_client",
|
||||
tags=["OAuth Provider"],
|
||||
tags=["OAuth Provider Admin"],
|
||||
)
|
||||
async def register_client(
|
||||
client_name: str = Form(..., description="Client application name"),
|
||||
redirect_uris: str = Form(..., description="Comma-separated list of redirect URIs"),
|
||||
redirect_uris: str = Form(..., description="Comma-separated redirect URIs"),
|
||||
client_type: str = Form(default="public", description="public or confidential"),
|
||||
scopes: str = Form(
|
||||
default="openid profile email",
|
||||
description="Allowed scopes (space-separated)",
|
||||
),
|
||||
mcp_server_url: str | None = Form(default=None, description="MCP server URL"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new OAuth client (skeleton).
|
||||
|
||||
In a full implementation, this would require admin authentication.
|
||||
"""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> dict:
|
||||
"""Register a new OAuth client."""
|
||||
# Parse redirect URIs
|
||||
uris = [uri.strip() for uri in redirect_uris.split(",") if uri.strip()]
|
||||
if not uris:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled",
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one redirect_uri is required",
|
||||
)
|
||||
|
||||
# NOTE: In production, this should require admin authentication
|
||||
# For now, this is a skeleton that shows the structure
|
||||
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
# Parse scopes
|
||||
allowed_scopes = [s.strip() for s in scopes.split() if s.strip()]
|
||||
|
||||
client_data = OAuthClientCreate(
|
||||
client_name=client_name,
|
||||
client_description=None,
|
||||
redirect_uris=[uri.strip() for uri in redirect_uris.split(",")],
|
||||
allowed_scopes=["openid", "profile", "email"],
|
||||
redirect_uris=uris,
|
||||
allowed_scopes=allowed_scopes,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
client, secret = await oauth_client.create_client(db, obj_in=client_data)
|
||||
client, secret = await oauth_client_crud.create_client(db, obj_in=client_data)
|
||||
|
||||
# Update MCP server URL if provided
|
||||
if mcp_server_url:
|
||||
client.mcp_server_url = mcp_server_url
|
||||
await db.commit()
|
||||
|
||||
result = {
|
||||
"client_id": client.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_type": client.client_type,
|
||||
"redirect_uris": client.redirect_uris,
|
||||
"allowed_scopes": client.allowed_scopes,
|
||||
}
|
||||
|
||||
if secret:
|
||||
result["client_secret"] = secret
|
||||
result["warning"] = (
|
||||
"Store the client_secret securely. It will not be shown again."
|
||||
"Store the client_secret securely! It will not be shown again."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/clients",
|
||||
response_model=list[OAuthClientResponse],
|
||||
summary="List OAuth Clients",
|
||||
description="List all registered OAuth clients (admin only).",
|
||||
operation_id="list_oauth_clients",
|
||||
tags=["OAuth Provider Admin"],
|
||||
)
|
||||
async def list_clients(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> list[OAuthClientResponse]:
|
||||
"""List all OAuth clients."""
|
||||
clients = await oauth_client_crud.get_all_clients(db)
|
||||
return [OAuthClientResponse.model_validate(c) for c in clients]
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/provider/clients/{client_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete OAuth Client",
|
||||
description="Delete an OAuth client (admin only). Revokes all tokens.",
|
||||
operation_id="delete_oauth_client",
|
||||
tags=["OAuth Provider Admin"],
|
||||
)
|
||||
async def delete_client(
|
||||
client_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> None:
|
||||
"""Delete an OAuth client."""
|
||||
client = await provider_service.get_client(db, client_id)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Client not found",
|
||||
)
|
||||
|
||||
await oauth_client_crud.delete_client(db, client_id=client_id)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# User Consent Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/consents",
|
||||
summary="List My Consents",
|
||||
description="List OAuth applications the current user has authorized.",
|
||||
operation_id="list_my_oauth_consents",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def list_my_consents(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> list[dict]:
|
||||
"""List applications the user has authorized."""
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_provider_token import OAuthConsent
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthConsent, OAuthClient)
|
||||
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
||||
.where(OAuthConsent.user_id == current_user.id)
|
||||
)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
"client_id": consent.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_description": client.client_description,
|
||||
"granted_scopes": consent.granted_scopes.split()
|
||||
if consent.granted_scopes
|
||||
else [],
|
||||
"granted_at": consent.created_at.isoformat(),
|
||||
}
|
||||
for consent, client in rows
|
||||
]
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/provider/consents/{client_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Revoke My Consent",
|
||||
description="Revoke authorization for an OAuth application. Also revokes all tokens.",
|
||||
operation_id="revoke_my_oauth_consent",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def revoke_my_consent(
|
||||
client_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> None:
|
||||
"""Revoke consent for an application."""
|
||||
revoked = await provider_service.revoke_consent(db, current_user.id, client_id)
|
||||
if not revoked:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No consent found for this client",
|
||||
)
|
||||
|
||||
@@ -515,11 +515,11 @@ class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
|
||||
client_secret_hash = None
|
||||
if obj_in.client_type == "confidential":
|
||||
client_secret = secrets.token_urlsafe(48)
|
||||
# In production, use proper password hashing (bcrypt)
|
||||
# For now, we store a hash placeholder
|
||||
import hashlib
|
||||
# SECURITY: Use bcrypt for secret storage (not SHA-256)
|
||||
# bcrypt is computationally expensive, making brute-force attacks infeasible
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
client_secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
client_secret_hash = get_password_hash(client_secret)
|
||||
|
||||
db_obj = OAuthClient(
|
||||
client_id=client_id,
|
||||
@@ -632,17 +632,82 @@ class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
|
||||
if client is None or client.client_secret_hash is None:
|
||||
return False
|
||||
|
||||
# Verify secret
|
||||
import hashlib
|
||||
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
||||
# This supports both old SHA-256 hashes (for migration) and new bcrypt hashes
|
||||
from app.core.auth import verify_password
|
||||
|
||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
# Cast to str for type safety with compare_digest
|
||||
stored_hash: str = str(client.client_secret_hash)
|
||||
return secrets.compare_digest(stored_hash, secret_hash)
|
||||
|
||||
# Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256
|
||||
if stored_hash.startswith("$2"):
|
||||
# New bcrypt format
|
||||
return verify_password(client_secret, stored_hash)
|
||||
else:
|
||||
# Legacy SHA-256 format - still support for migration
|
||||
import hashlib
|
||||
|
||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
return secrets.compare_digest(stored_hash, secret_hash)
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error verifying client secret: {e!s}")
|
||||
return False
|
||||
|
||||
async def get_all_clients(
|
||||
self, db: AsyncSession, *, include_inactive: bool = False
|
||||
) -> list[OAuthClient]:
|
||||
"""
|
||||
Get all OAuth clients.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
include_inactive: Whether to include inactive clients
|
||||
|
||||
Returns:
|
||||
List of OAuthClient objects
|
||||
"""
|
||||
try:
|
||||
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
||||
if not include_inactive:
|
||||
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error getting all OAuth clients: {e!s}")
|
||||
raise
|
||||
|
||||
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
||||
"""
|
||||
Delete an OAuth client permanently.
|
||||
|
||||
Note: This will cascade delete related records (tokens, consents, etc.)
|
||||
due to foreign key constraints.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(f"OAuth client deleted: {client_id}")
|
||||
else:
|
||||
logger.warning(f"OAuth client not found for deletion: {client_id}")
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton instances
|
||||
|
||||
@@ -8,9 +8,13 @@ from app.core.database import Base
|
||||
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# OAuth models
|
||||
# OAuth models (client mode - authenticate via Google/GitHub)
|
||||
from .oauth_account import OAuthAccount
|
||||
|
||||
# OAuth provider models (server mode - act as authorization server for MCP)
|
||||
from .oauth_authorization_code import OAuthAuthorizationCode
|
||||
from .oauth_client import OAuthClient
|
||||
from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
||||
from .oauth_state import OAuthState
|
||||
from .organization import Organization
|
||||
|
||||
@@ -22,7 +26,10 @@ from .user_session import UserSession
|
||||
__all__ = [
|
||||
"Base",
|
||||
"OAuthAccount",
|
||||
"OAuthAuthorizationCode",
|
||||
"OAuthClient",
|
||||
"OAuthConsent",
|
||||
"OAuthProviderRefreshToken",
|
||||
"OAuthState",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
|
||||
97
backend/app/models/oauth_authorization_code.py
Normal file
97
backend/app/models/oauth_authorization_code.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""OAuth authorization code model for OAuth provider mode."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
OAuth 2.0 Authorization Code for the authorization code flow.
|
||||
|
||||
Authorization codes are:
|
||||
- Single-use (marked as used after exchange)
|
||||
- Short-lived (10 minutes default)
|
||||
- Bound to specific client, user, redirect_uri
|
||||
- Support PKCE (code_challenge/code_challenge_method)
|
||||
|
||||
Security considerations:
|
||||
- Code must be cryptographically random (64 chars, URL-safe)
|
||||
- Must validate redirect_uri matches exactly
|
||||
- Must verify PKCE code_verifier for public clients
|
||||
- Must be consumed within expiration time
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_authorization_codes"
|
||||
|
||||
# The authorization code (cryptographically random, URL-safe)
|
||||
code = Column(String(128), unique=True, nullable=False, index=True)
|
||||
|
||||
# Client that requested the code
|
||||
client_id = Column(
|
||||
String(64),
|
||||
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# User who authorized the request
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Redirect URI (must match exactly on token exchange)
|
||||
redirect_uri = Column(String(2048), nullable=False)
|
||||
|
||||
# Granted scopes (space-separated)
|
||||
scope = Column(String(1000), nullable=False, default="")
|
||||
|
||||
# PKCE support (required for public clients)
|
||||
code_challenge = Column(String(128), nullable=True)
|
||||
code_challenge_method = Column(String(10), nullable=True) # "S256" or "plain"
|
||||
|
||||
# State parameter (for CSRF protection, returned to client)
|
||||
state = Column(String(256), nullable=True)
|
||||
|
||||
# Nonce (for OpenID Connect, included in ID token)
|
||||
nonce = Column(String(256), nullable=True)
|
||||
|
||||
# Expiration (codes are short-lived, typically 10 minutes)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Single-use flag (set to True after successful exchange)
|
||||
used = Column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClient", backref="authorization_codes")
|
||||
user = relationship("User", backref="oauth_authorization_codes")
|
||||
|
||||
# Indexes for efficient cleanup queries
|
||||
__table_args__ = (
|
||||
Index("ix_oauth_authorization_codes_expires_at", "expires_at"),
|
||||
Index("ix_oauth_authorization_codes_client_user", "client_id", "user_id"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthAuthorizationCode {self.code[:8]}... for {self.client_id}>"
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the authorization code has expired."""
|
||||
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return now > expires_at
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the authorization code is valid (not used, not expired)."""
|
||||
return not self.used and not self.is_expired
|
||||
159
backend/app/models/oauth_provider_token.py
Normal file
159
backend/app/models/oauth_provider_token.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""OAuth provider token models for OAuth provider mode."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
OAuth 2.0 Refresh Token for the OAuth provider.
|
||||
|
||||
Refresh tokens are:
|
||||
- Opaque (stored as hash in DB, actual token given to client)
|
||||
- Long-lived (configurable, default 30 days)
|
||||
- Revocable (via revoked flag or deletion)
|
||||
- Bound to specific client, user, and scope
|
||||
|
||||
Access tokens are JWTs and not stored in DB (self-contained).
|
||||
This model only tracks refresh tokens for revocation support.
|
||||
|
||||
Security considerations:
|
||||
- Store token hash, not plaintext
|
||||
- Support token rotation (new refresh token on use)
|
||||
- Track last used time for security auditing
|
||||
- Support revocation by user, client, or admin
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_provider_refresh_tokens"
|
||||
|
||||
# Hash of the refresh token (SHA-256)
|
||||
# We store hash, not plaintext, for security
|
||||
token_hash = Column(String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Unique token ID (JTI) - used in JWT access tokens to reference this refresh token
|
||||
jti = Column(String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Client that owns this token
|
||||
client_id = Column(
|
||||
String(64),
|
||||
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# User who authorized this token
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Granted scopes (space-separated)
|
||||
scope = Column(String(1000), nullable=False, default="")
|
||||
|
||||
# Token expiration
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Revocation flag
|
||||
revoked = Column(Boolean, default=False, nullable=False, index=True)
|
||||
|
||||
# Last used timestamp (for security auditing)
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Device/session info (optional, for user visibility)
|
||||
device_info = Column(String(500), nullable=True)
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClient", backref="refresh_tokens")
|
||||
user = relationship("User", backref="oauth_provider_refresh_tokens")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index("ix_oauth_provider_refresh_tokens_expires_at", "expires_at"),
|
||||
Index("ix_oauth_provider_refresh_tokens_client_user", "client_id", "user_id"),
|
||||
Index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
"user_id",
|
||||
"revoked",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
status = "revoked" if self.revoked else "active"
|
||||
return f"<OAuthProviderRefreshToken {self.jti[:8]}... ({status})>"
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the refresh token has expired."""
|
||||
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return now > expires_at
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the refresh token is valid (not revoked, not expired)."""
|
||||
return not self.revoked and not self.is_expired
|
||||
|
||||
|
||||
class OAuthConsent(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
OAuth consent record - remembers user consent for a client.
|
||||
|
||||
When a user grants consent to an OAuth client, we store the record
|
||||
so they don't have to re-consent on subsequent authorizations
|
||||
(unless scopes change).
|
||||
|
||||
This enables a better UX - users only see consent screen once per client,
|
||||
unless the client requests additional scopes.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_consents"
|
||||
|
||||
# User who granted consent
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Client that received consent
|
||||
client_id = Column(
|
||||
String(64),
|
||||
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Granted scopes (space-separated)
|
||||
granted_scopes = Column(String(1000), nullable=False, default="")
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClient", backref="consents")
|
||||
user = relationship("User", backref="oauth_consents")
|
||||
|
||||
# Unique constraint: one consent record per user+client
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_oauth_consents_user_client",
|
||||
"user_id",
|
||||
"client_id",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthConsent user={self.user_id} client={self.client_id}>"
|
||||
|
||||
def has_scopes(self, requested_scopes: list[str]) -> bool:
|
||||
"""Check if all requested scopes are already granted."""
|
||||
granted = set(self.granted_scopes.split()) if self.granted_scopes else set()
|
||||
requested = set(requested_scopes)
|
||||
return requested.issubset(granted)
|
||||
@@ -284,6 +284,9 @@ class OAuthServerMetadata(BaseModel):
|
||||
revocation_endpoint: str | None = Field(
|
||||
None, description="Token revocation endpoint"
|
||||
)
|
||||
introspection_endpoint: str | None = Field(
|
||||
None, description="Token introspection endpoint (RFC 7662)"
|
||||
)
|
||||
scopes_supported: list[str] = Field(
|
||||
default_factory=list, description="Supported scopes"
|
||||
)
|
||||
@@ -297,6 +300,10 @@ class OAuthServerMetadata(BaseModel):
|
||||
code_challenge_methods_supported: list[str] = Field(
|
||||
default_factory=lambda: ["S256"], description="Supported PKCE methods"
|
||||
)
|
||||
token_endpoint_auth_methods_supported: list[str] = Field(
|
||||
default_factory=lambda: ["client_secret_basic", "client_secret_post", "none"],
|
||||
description="Supported client authentication methods",
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
@@ -304,10 +311,85 @@ class OAuthServerMetadata(BaseModel):
|
||||
"issuer": "https://api.example.com",
|
||||
"authorization_endpoint": "https://api.example.com/oauth/authorize",
|
||||
"token_endpoint": "https://api.example.com/oauth/token",
|
||||
"revocation_endpoint": "https://api.example.com/oauth/revoke",
|
||||
"introspection_endpoint": "https://api.example.com/oauth/introspect",
|
||||
"scopes_supported": ["openid", "profile", "email", "read:users"],
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none",
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Token Responses (RFC 6749)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthTokenResponse(BaseModel):
|
||||
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
|
||||
|
||||
access_token: str = Field(..., description="The access token issued by the server")
|
||||
token_type: str = Field(
|
||||
default="Bearer", description="The type of token (typically 'Bearer')"
|
||||
)
|
||||
expires_in: int | None = Field(None, description="Token lifetime in seconds")
|
||||
refresh_token: str | None = Field(
|
||||
None, description="Refresh token for obtaining new access tokens"
|
||||
)
|
||||
scope: str | None = Field(
|
||||
None, description="Space-separated list of granted scopes"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "dGhpcyBpcyBhIHJlZnJlc2ggdG9rZW4...",
|
||||
"scope": "openid profile email",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthTokenIntrospectionResponse(BaseModel):
|
||||
"""OAuth 2.0 Token Introspection Response (RFC 7662)."""
|
||||
|
||||
active: bool = Field(..., description="Whether the token is currently active")
|
||||
scope: str | None = Field(None, description="Space-separated list of scopes")
|
||||
client_id: str | None = Field(None, description="Client identifier for the token")
|
||||
username: str | None = Field(
|
||||
None, description="Human-readable identifier for the resource owner"
|
||||
)
|
||||
token_type: str | None = Field(
|
||||
None, description="Type of the token (e.g., 'Bearer')"
|
||||
)
|
||||
exp: int | None = Field(None, description="Token expiration time (Unix timestamp)")
|
||||
iat: int | None = Field(None, description="Token issue time (Unix timestamp)")
|
||||
nbf: int | None = Field(None, description="Token not-before time (Unix timestamp)")
|
||||
sub: str | None = Field(None, description="Subject of the token (user ID)")
|
||||
aud: str | None = Field(None, description="Intended audience of the token")
|
||||
iss: str | None = Field(None, description="Issuer of the token")
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"active": True,
|
||||
"scope": "openid profile",
|
||||
"client_id": "client123",
|
||||
"username": "user@example.com",
|
||||
"token_type": "Bearer",
|
||||
"exp": 1735689600,
|
||||
"iat": 1735686000,
|
||||
"sub": "user-uuid-here",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
1069
backend/app/services/oauth_provider_service.py
Normal file
1069
backend/app/services/oauth_provider_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -246,6 +246,15 @@ class OAuthService:
|
||||
if not state_record:
|
||||
raise AuthenticationError("Invalid or expired OAuth state")
|
||||
|
||||
# SECURITY: Validate redirect_uri matches the one from authorization request
|
||||
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
|
||||
if state_record.redirect_uri != redirect_uri:
|
||||
logger.warning(
|
||||
f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, "
|
||||
f"got {redirect_uri}"
|
||||
)
|
||||
raise AuthenticationError("Redirect URI mismatch")
|
||||
|
||||
# Extract provider from state record (str for type safety)
|
||||
provider: str = str(state_record.provider)
|
||||
|
||||
@@ -272,6 +281,19 @@ class OAuthService:
|
||||
config["token_url"],
|
||||
**token_params,
|
||||
)
|
||||
|
||||
# SECURITY: Validate ID token signature and nonce for OpenID Connect
|
||||
# This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
|
||||
if provider == "google" and state_record.nonce:
|
||||
id_token = token.get("id_token")
|
||||
if id_token:
|
||||
await OAuthService._verify_google_id_token(
|
||||
id_token=str(id_token),
|
||||
expected_nonce=str(state_record.nonce),
|
||||
client_id=client_id,
|
||||
)
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth token exchange failed: {e!s}")
|
||||
raise AuthenticationError("Failed to exchange authorization code")
|
||||
@@ -294,8 +316,11 @@ class OAuthService:
|
||||
# Process user info and create/link account
|
||||
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
|
||||
# Email can be None if user didn't grant email permission
|
||||
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
|
||||
email_raw = user_info.get("email")
|
||||
provider_email: str | None = str(email_raw) if email_raw else None
|
||||
provider_email: str | None = (
|
||||
str(email_raw).lower().strip() if email_raw else None
|
||||
)
|
||||
|
||||
if not provider_user_id:
|
||||
raise AuthenticationError("Provider did not return user ID")
|
||||
@@ -479,6 +504,106 @@ class OAuthService:
|
||||
|
||||
return user_info
|
||||
|
||||
# Google's OIDC configuration endpoints
|
||||
GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||
GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
|
||||
|
||||
@staticmethod
|
||||
async def _verify_google_id_token(
|
||||
id_token: str,
|
||||
expected_nonce: str,
|
||||
client_id: str,
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Verify Google ID token signature and claims.
|
||||
|
||||
SECURITY: This properly verifies the ID token by:
|
||||
1. Fetching Google's public keys (JWKS)
|
||||
2. Verifying the JWT signature against the public key
|
||||
3. Validating issuer, audience, expiry, and nonce claims
|
||||
|
||||
Args:
|
||||
id_token: The ID token JWT string
|
||||
expected_nonce: The nonce we sent in the authorization request
|
||||
client_id: Our OAuth client ID (expected audience)
|
||||
|
||||
Returns:
|
||||
Decoded ID token payload
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If verification fails
|
||||
"""
|
||||
import httpx
|
||||
from jose import jwt as jose_jwt
|
||||
from jose.exceptions import JWTError
|
||||
|
||||
try:
|
||||
# Fetch Google's public keys (JWKS)
|
||||
# In production, consider caching this with TTL matching Cache-Control header
|
||||
async with httpx.AsyncClient() as client:
|
||||
jwks_response = await client.get(
|
||||
OAuthService.GOOGLE_JWKS_URL,
|
||||
timeout=10.0,
|
||||
)
|
||||
jwks_response.raise_for_status()
|
||||
jwks = jwks_response.json()
|
||||
|
||||
# Get the key ID from the token header
|
||||
unverified_header = jose_jwt.get_unverified_header(id_token)
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
raise AuthenticationError("ID token missing key ID (kid)")
|
||||
|
||||
# Find the matching public key
|
||||
public_key = None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
public_key = key
|
||||
break
|
||||
|
||||
if not public_key:
|
||||
raise AuthenticationError("ID token signed with unknown key")
|
||||
|
||||
# Verify the token signature and decode claims
|
||||
# jose library will verify signature against the JWK
|
||||
payload = jose_jwt.decode(
|
||||
id_token,
|
||||
public_key,
|
||||
algorithms=["RS256"], # Google uses RS256
|
||||
audience=client_id,
|
||||
issuer=OAuthService.GOOGLE_ISSUERS,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_aud": True,
|
||||
"verify_iss": True,
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Verify nonce (OIDC replay attack protection)
|
||||
token_nonce = payload.get("nonce")
|
||||
if token_nonce != expected_nonce:
|
||||
logger.warning(
|
||||
f"OAuth ID token nonce mismatch: expected {expected_nonce}, "
|
||||
f"got {token_nonce}"
|
||||
)
|
||||
raise AuthenticationError("Invalid ID token nonce")
|
||||
|
||||
logger.debug("Google ID token verified successfully")
|
||||
return payload
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"Google ID token verification failed: {e}")
|
||||
raise AuthenticationError("Invalid ID token signature")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Failed to fetch Google JWKS: {e}")
|
||||
# If we can't verify the ID token, fail closed for security
|
||||
raise AuthenticationError("Failed to verify ID token")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error verifying Google ID token: {e}")
|
||||
raise AuthenticationError("ID token verification error")
|
||||
|
||||
@staticmethod
|
||||
async def _create_oauth_user(
|
||||
db: AsyncSession,
|
||||
|
||||
@@ -344,8 +344,8 @@ class TestOAuthProviderEndpoints:
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_authorize_skeleton(self, client, async_test_db):
|
||||
"""Test provider authorize returns not implemented (skeleton)."""
|
||||
async def test_provider_authorize_requires_auth(self, client, async_test_db):
|
||||
"""Test provider authorize requires authentication."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a test client
|
||||
@@ -374,12 +374,12 @@ class TestOAuthProviderEndpoints:
|
||||
"redirect_uri": "http://localhost:3000/callback",
|
||||
},
|
||||
)
|
||||
# Should return 501 Not Implemented (skeleton)
|
||||
assert response.status_code == 501
|
||||
# Authorize endpoint requires authentication
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_token_skeleton(self, client):
|
||||
"""Test provider token returns not implemented (skeleton)."""
|
||||
async def test_provider_token_requires_client_id(self, client):
|
||||
"""Test provider token requires client_id."""
|
||||
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||
|
||||
@@ -390,5 +390,5 @@ class TestOAuthProviderEndpoints:
|
||||
"code": "test_code",
|
||||
},
|
||||
)
|
||||
# Should return 501 Not Implemented (skeleton)
|
||||
assert response.status_code == 501
|
||||
# Missing client_id returns 401 (invalid_client)
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -203,3 +203,168 @@ async def e2e_client(async_postgres_url):
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def e2e_superuser(e2e_client):
|
||||
"""
|
||||
Create a superuser and return credentials + tokens.
|
||||
|
||||
Returns dict with: email, password, tokens, user_id
|
||||
"""
|
||||
from uuid import uuid4
|
||||
|
||||
email = f"admin-{uuid4().hex[:8]}@example.com"
|
||||
password = "SuperAdmin123!"
|
||||
|
||||
# Register via API first to get proper password hashing
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Super",
|
||||
"last_name": "Admin",
|
||||
},
|
||||
)
|
||||
|
||||
# Login to get tokens
|
||||
login_resp = await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
tokens = login_resp.json()
|
||||
|
||||
# Now we need to make this user a superuser directly via SQL
|
||||
# Get the db session from the client's override
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.main import app
|
||||
|
||||
async for db in app.dependency_overrides[get_db]():
|
||||
# Update user to be superuser
|
||||
await db.execute(
|
||||
text("UPDATE users SET is_superuser = true WHERE email = :email"),
|
||||
{"email": email},
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# Get user ID
|
||||
result = await db.execute(
|
||||
text("SELECT id FROM users WHERE email = :email"),
|
||||
{"email": email},
|
||||
)
|
||||
user_id = str(result.scalar())
|
||||
break
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"password": password,
|
||||
"tokens": tokens,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def e2e_org_with_members(e2e_client, e2e_superuser):
|
||||
"""
|
||||
Create an organization with owner and member.
|
||||
|
||||
Returns dict with: org_id, org_slug, owner (tokens), member (tokens)
|
||||
"""
|
||||
from uuid import uuid4
|
||||
|
||||
# Create organization via admin API
|
||||
org_name = f"Test Org {uuid4().hex[:8]}"
|
||||
org_slug = f"test-org-{uuid4().hex[:8]}"
|
||||
|
||||
create_resp = await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"},
|
||||
json={
|
||||
"name": org_name,
|
||||
"slug": org_slug,
|
||||
"description": "Test organization for E2E tests",
|
||||
},
|
||||
)
|
||||
org_data = create_resp.json()
|
||||
org_id = org_data["id"]
|
||||
|
||||
# Create owner user
|
||||
owner_email = f"owner-{uuid4().hex[:8]}@example.com"
|
||||
owner_password = "OwnerPass123!"
|
||||
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": owner_email,
|
||||
"password": owner_password,
|
||||
"first_name": "Org",
|
||||
"last_name": "Owner",
|
||||
},
|
||||
)
|
||||
owner_login = await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": owner_email, "password": owner_password},
|
||||
)
|
||||
owner_tokens = owner_login.json()
|
||||
|
||||
# Get owner user ID
|
||||
owner_me = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {owner_tokens['access_token']}"},
|
||||
)
|
||||
owner_id = owner_me.json()["id"]
|
||||
|
||||
# Add owner to organization as owner role
|
||||
await e2e_client.post(
|
||||
f"/api/v1/admin/organizations/{org_id}/members",
|
||||
headers={"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"},
|
||||
json={"user_id": owner_id, "role": "owner"},
|
||||
)
|
||||
|
||||
# Create member user
|
||||
member_email = f"member-{uuid4().hex[:8]}@example.com"
|
||||
member_password = "MemberPass123!"
|
||||
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": member_email,
|
||||
"password": member_password,
|
||||
"first_name": "Org",
|
||||
"last_name": "Member",
|
||||
},
|
||||
)
|
||||
member_login = await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": member_email, "password": member_password},
|
||||
)
|
||||
member_tokens = member_login.json()
|
||||
|
||||
# Get member user ID
|
||||
member_me = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {member_tokens['access_token']}"},
|
||||
)
|
||||
member_id = member_me.json()["id"]
|
||||
|
||||
# Add member to organization
|
||||
await e2e_client.post(
|
||||
f"/api/v1/admin/organizations/{org_id}/members",
|
||||
headers={"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"},
|
||||
json={"user_id": member_id, "role": "member"},
|
||||
)
|
||||
|
||||
return {
|
||||
"org_id": org_id,
|
||||
"org_slug": org_slug,
|
||||
"org_name": org_name,
|
||||
"owner": {"email": owner_email, "tokens": owner_tokens, "user_id": owner_id},
|
||||
"member": {
|
||||
"email": member_email,
|
||||
"tokens": member_tokens,
|
||||
"user_id": member_id,
|
||||
},
|
||||
}
|
||||
|
||||
648
backend/tests/e2e/test_admin_superuser_workflows.py
Normal file
648
backend/tests/e2e/test_admin_superuser_workflows.py
Normal file
@@ -0,0 +1,648 @@
|
||||
"""
|
||||
Admin superuser E2E workflow tests with real PostgreSQL.
|
||||
|
||||
These tests validate admin operations with actual superuser privileges:
|
||||
- User management (list, create, update, delete, bulk actions)
|
||||
- Organization management (create, update, delete, members)
|
||||
- Admin statistics
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.e2e,
|
||||
pytest.mark.postgres,
|
||||
pytest.mark.asyncio,
|
||||
]
|
||||
|
||||
|
||||
class TestAdminUserManagement:
|
||||
"""Test admin user management with superuser."""
|
||||
|
||||
async def test_admin_list_users(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can list all users."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "pagination" in data
|
||||
assert len(data["data"]) >= 1 # At least the superuser
|
||||
|
||||
async def test_admin_list_users_with_pagination(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can list users with pagination."""
|
||||
# Create a few more users
|
||||
for i in range(3):
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": f"user{i}-{uuid4().hex[:8]}@example.com",
|
||||
"password": "TestPass123!",
|
||||
"first_name": f"User{i}",
|
||||
"last_name": "Test",
|
||||
},
|
||||
)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
params={"page": 1, "limit": 2},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) <= 2
|
||||
assert data["pagination"]["page_size"] <= 2
|
||||
|
||||
async def test_admin_create_user(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can create new users."""
|
||||
email = f"newuser-{uuid4().hex[:8]}@example.com"
|
||||
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={
|
||||
"email": email,
|
||||
"password": "NewUserPass123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code in [200, 201]
|
||||
data = response.json()
|
||||
assert data["email"] == email
|
||||
|
||||
async def test_admin_get_user_by_id(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can get any user by ID."""
|
||||
# Create a user
|
||||
email = f"target-{uuid4().hex[:8]}@example.com"
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "TargetPass123!",
|
||||
"first_name": "Target",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Get user list to find the ID
|
||||
list_resp = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
users = list_resp.json()["data"]
|
||||
target_user = next(u for u in users if u["email"] == email)
|
||||
|
||||
# Get user by ID
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/admin/users/{target_user['id']}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["email"] == email
|
||||
|
||||
async def test_admin_update_user(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can update any user."""
|
||||
# Create a user
|
||||
email = f"update-{uuid4().hex[:8]}@example.com"
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "UpdatePass123!",
|
||||
"first_name": "Update",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Get user ID
|
||||
list_resp = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
users = list_resp.json()["data"]
|
||||
target_user = next(u for u in users if u["email"] == email)
|
||||
|
||||
# Update user
|
||||
response = await e2e_client.put(
|
||||
f"/api/v1/admin/users/{target_user['id']}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"first_name": "Updated", "last_name": "Name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["first_name"] == "Updated"
|
||||
|
||||
async def test_admin_deactivate_user(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can deactivate users."""
|
||||
# Create a user
|
||||
email = f"deactivate-{uuid4().hex[:8]}@example.com"
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "DeactivatePass123!",
|
||||
"first_name": "Deactivate",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Get user ID
|
||||
list_resp = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
users = list_resp.json()["data"]
|
||||
target_user = next(u for u in users if u["email"] == email)
|
||||
|
||||
# Deactivate user
|
||||
response = await e2e_client.post(
|
||||
f"/api/v1/admin/users/{target_user['id']}/deactivate",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_admin_bulk_action(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can perform bulk actions on users."""
|
||||
# Create users for bulk action
|
||||
user_ids = []
|
||||
for i in range(2):
|
||||
email = f"bulk-{i}-{uuid4().hex[:8]}@example.com"
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "BulkPass123!",
|
||||
"first_name": f"Bulk{i}",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Get user IDs
|
||||
list_resp = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
users = list_resp.json()["data"]
|
||||
bulk_users = [u for u in users if u["email"].startswith("bulk-")]
|
||||
user_ids = [u["id"] for u in bulk_users]
|
||||
|
||||
# Bulk deactivate
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/admin/users/bulk-action",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"action": "deactivate", "user_ids": user_ids},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["affected_count"] >= 1
|
||||
|
||||
|
||||
class TestAdminOrganizationManagement:
|
||||
"""Test admin organization management with superuser."""
|
||||
|
||||
async def test_admin_list_organizations(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can list all organizations."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "pagination" in data
|
||||
|
||||
async def test_admin_create_organization(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can create organizations."""
|
||||
org_name = f"Admin Org {uuid4().hex[:8]}"
|
||||
org_slug = f"admin-org-{uuid4().hex[:8]}"
|
||||
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={
|
||||
"name": org_name,
|
||||
"slug": org_slug,
|
||||
"description": "Created by admin",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code in [200, 201]
|
||||
data = response.json()
|
||||
assert data["name"] == org_name
|
||||
assert data["slug"] == org_slug
|
||||
|
||||
async def test_admin_get_organization(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can get organization details."""
|
||||
# Create org first
|
||||
org_slug = f"get-org-{uuid4().hex[:8]}"
|
||||
create_resp = await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={
|
||||
"name": "Get Org Test",
|
||||
"slug": org_slug,
|
||||
},
|
||||
)
|
||||
org_id = create_resp.json()["id"]
|
||||
|
||||
# Get org
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/admin/organizations/{org_id}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["slug"] == org_slug
|
||||
|
||||
async def test_admin_update_organization(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can update organizations."""
|
||||
# Create org
|
||||
org_slug = f"update-org-{uuid4().hex[:8]}"
|
||||
create_resp = await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"name": "Update Org Test", "slug": org_slug},
|
||||
)
|
||||
org_id = create_resp.json()["id"]
|
||||
|
||||
# Update org
|
||||
response = await e2e_client.put(
|
||||
f"/api/v1/admin/organizations/{org_id}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"name": "Updated Org Name", "description": "Updated description"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated Org Name"
|
||||
|
||||
async def test_admin_add_member_to_organization(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can add members to organizations."""
|
||||
# Create org
|
||||
org_slug = f"member-org-{uuid4().hex[:8]}"
|
||||
create_resp = await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"name": "Member Org Test", "slug": org_slug},
|
||||
)
|
||||
org_id = create_resp.json()["id"]
|
||||
|
||||
# Create user to add
|
||||
email = f"new-member-{uuid4().hex[:8]}@example.com"
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "MemberPass123!",
|
||||
"first_name": "New",
|
||||
"last_name": "Member",
|
||||
},
|
||||
)
|
||||
|
||||
# Get user ID
|
||||
list_resp = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
users = list_resp.json()["data"]
|
||||
new_user = next(u for u in users if u["email"] == email)
|
||||
|
||||
# Add to org
|
||||
response = await e2e_client.post(
|
||||
f"/api/v1/admin/organizations/{org_id}/members",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"user_id": new_user["id"], "role": "member"},
|
||||
)
|
||||
|
||||
assert response.status_code in [200, 201]
|
||||
|
||||
async def test_admin_list_organization_members(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can list organization members."""
|
||||
# Create org with member
|
||||
org_slug = f"list-members-org-{uuid4().hex[:8]}"
|
||||
create_resp = await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"name": "List Members Org", "slug": org_slug},
|
||||
)
|
||||
org_id = create_resp.json()["id"]
|
||||
|
||||
# List members
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/admin/organizations/{org_id}/members",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestAdminStats:
|
||||
"""Test admin statistics endpoints."""
|
||||
|
||||
async def test_admin_get_stats(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can get admin statistics."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/stats",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Stats should have user growth, org distribution, etc.
|
||||
assert "user_growth" in data or "user_status" in data
|
||||
|
||||
|
||||
class TestAdminSessionManagement:
|
||||
"""Test admin session management."""
|
||||
|
||||
async def test_admin_list_all_sessions(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can list all sessions."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/sessions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
|
||||
|
||||
class TestAdminDeleteOperations:
|
||||
"""Test admin delete operations."""
|
||||
|
||||
async def test_admin_delete_user(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can delete users."""
|
||||
# Create user
|
||||
email = f"delete-{uuid4().hex[:8]}@example.com"
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "DeletePass123!",
|
||||
"first_name": "Delete",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Get user ID
|
||||
list_resp = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
users = list_resp.json()["data"]
|
||||
target_user = next(u for u in users if u["email"] == email)
|
||||
|
||||
# Delete user
|
||||
response = await e2e_client.delete(
|
||||
f"/api/v1/admin/users/{target_user['id']}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code in [200, 204]
|
||||
|
||||
async def test_admin_delete_organization(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can delete organizations."""
|
||||
# Create org
|
||||
org_slug = f"delete-org-{uuid4().hex[:8]}"
|
||||
create_resp = await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"name": "Delete Org Test", "slug": org_slug},
|
||||
)
|
||||
org_id = create_resp.json()["id"]
|
||||
|
||||
# Delete org
|
||||
response = await e2e_client.delete(
|
||||
f"/api/v1/admin/organizations/{org_id}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code in [200, 204]
|
||||
|
||||
async def test_admin_remove_org_member(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can remove members from organizations."""
|
||||
# Create org
|
||||
org_slug = f"remove-member-org-{uuid4().hex[:8]}"
|
||||
create_resp = await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"name": "Remove Member Org", "slug": org_slug},
|
||||
)
|
||||
org_id = create_resp.json()["id"]
|
||||
|
||||
# Create user
|
||||
email = f"remove-member-{uuid4().hex[:8]}@example.com"
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "RemovePass123!",
|
||||
"first_name": "Remove",
|
||||
"last_name": "Member",
|
||||
},
|
||||
)
|
||||
|
||||
# Get user ID
|
||||
list_resp = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
users = list_resp.json()["data"]
|
||||
target_user = next(u for u in users if u["email"] == email)
|
||||
|
||||
# Add to org
|
||||
await e2e_client.post(
|
||||
f"/api/v1/admin/organizations/{org_id}/members",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"user_id": target_user["id"], "role": "member"},
|
||||
)
|
||||
|
||||
# Remove from org
|
||||
response = await e2e_client.delete(
|
||||
f"/api/v1/admin/organizations/{org_id}/members/{target_user['id']}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code in [200, 204]
|
||||
|
||||
|
||||
class TestAdminSearchAndFilter:
|
||||
"""Test admin search and filter capabilities."""
|
||||
|
||||
async def test_admin_search_users_by_email(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can search users by email."""
|
||||
# Create user with unique prefix
|
||||
prefix = f"searchable-{uuid4().hex[:8]}"
|
||||
email = f"{prefix}@example.com"
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "SearchPass123!",
|
||||
"first_name": "Search",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
params={"search": prefix},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Search should find the user
|
||||
assert len(data["data"]) >= 1
|
||||
emails = [u["email"] for u in data["data"]]
|
||||
assert any(prefix in e for e in emails)
|
||||
|
||||
async def test_admin_filter_active_users(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can filter by active status."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
params={"is_active": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# All returned users should be active
|
||||
for user in data["data"]:
|
||||
assert user["is_active"] is True
|
||||
|
||||
async def test_admin_filter_superusers(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can filter superusers."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
params={"is_superuser": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should find at least the test superuser
|
||||
assert len(data["data"]) >= 1
|
||||
|
||||
async def test_admin_sort_users(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can sort users by different fields."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
params={"sort_by": "created_at", "sort_order": "desc"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
|
||||
async def test_admin_search_organizations(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can search organizations."""
|
||||
# Create org with unique name
|
||||
prefix = f"searchorg-{uuid4().hex[:8]}"
|
||||
await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
json={"name": f"{prefix} Test", "slug": f"{prefix}-slug"},
|
||||
)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
params={"search": prefix},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) >= 1
|
||||
212
backend/tests/e2e/test_admin_workflows.py
Normal file
212
backend/tests/e2e/test_admin_workflows.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Admin E2E workflow tests with real PostgreSQL.
|
||||
|
||||
These tests validate complete admin workflows including:
|
||||
- User management (list, create, update, delete, bulk actions)
|
||||
- Organization management (create, update, delete, members)
|
||||
- Admin statistics
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.e2e,
|
||||
pytest.mark.postgres,
|
||||
pytest.mark.asyncio,
|
||||
]
|
||||
|
||||
|
||||
async def register_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
|
||||
"""Helper to register a user."""
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def login_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
|
||||
"""Helper to login a user."""
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def create_superuser(e2e_db_session, email: str, password: str):
|
||||
"""Create a superuser directly in the database."""
|
||||
from app.crud.user import user as user_crud
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
user_in = UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True,
|
||||
)
|
||||
user = await user_crud.create(e2e_db_session, obj_in=user_in)
|
||||
return user
|
||||
|
||||
|
||||
class TestAdminUserManagementWorkflows:
|
||||
"""Test admin user management workflows."""
|
||||
|
||||
async def test_regular_user_cannot_access_admin_endpoints(self, e2e_client):
|
||||
"""Regular users cannot access admin endpoints."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
await register_user(e2e_client, email)
|
||||
tokens = await login_user(e2e_client, email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_admin_stats_requires_superuser(self, e2e_client):
|
||||
"""Admin stats endpoint requires superuser."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
await register_user(e2e_client, email)
|
||||
tokens = await login_user(e2e_client, email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/stats",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_admin_create_user_requires_superuser(self, e2e_client):
|
||||
"""Creating users via admin endpoint requires superuser."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
await register_user(e2e_client, email)
|
||||
tokens = await login_user(e2e_client, email)
|
||||
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/admin/users",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"email": f"newuser-{uuid4().hex[:8]}@example.com",
|
||||
"password": "NewUserPass123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestAdminOrganizationWorkflows:
|
||||
"""Test admin organization management workflows."""
|
||||
|
||||
async def test_regular_user_cannot_list_admin_orgs(self, e2e_client):
|
||||
"""Regular users cannot list organizations via admin endpoint."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
await register_user(e2e_client, email)
|
||||
tokens = await login_user(e2e_client, email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_regular_user_cannot_create_org_via_admin(self, e2e_client):
|
||||
"""Regular users cannot create organizations via admin endpoint."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
await register_user(e2e_client, email)
|
||||
tokens = await login_user(e2e_client, email)
|
||||
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"name": "Test Org",
|
||||
"slug": f"test-org-{uuid4().hex[:8]}",
|
||||
"description": "Test organization",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestAdminSessionWorkflows:
|
||||
"""Test admin session management workflows."""
|
||||
|
||||
async def test_regular_user_cannot_list_admin_sessions(self, e2e_client):
|
||||
"""Regular users cannot list sessions via admin endpoint."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
await register_user(e2e_client, email)
|
||||
tokens = await login_user(e2e_client, email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/sessions",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestAdminBulkOperations:
|
||||
"""Test admin bulk operation workflows."""
|
||||
|
||||
async def test_regular_user_cannot_bulk_activate_users(self, e2e_client):
|
||||
"""Regular users cannot perform bulk user activation."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
await register_user(e2e_client, email)
|
||||
tokens = await login_user(e2e_client, email)
|
||||
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/admin/users/bulk-action",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"action": "activate",
|
||||
"user_ids": [str(uuid4())],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestAdminAuthorizationBoundaries:
|
||||
"""Test admin authorization security boundaries."""
|
||||
|
||||
async def test_unauthenticated_cannot_access_admin(self, e2e_client):
|
||||
"""Unauthenticated requests cannot access admin endpoints."""
|
||||
endpoints = [
|
||||
("/api/v1/admin/users", "get"),
|
||||
("/api/v1/admin/organizations", "get"),
|
||||
("/api/v1/admin/sessions", "get"),
|
||||
("/api/v1/admin/stats", "get"),
|
||||
]
|
||||
|
||||
for endpoint, method in endpoints:
|
||||
if method == "get":
|
||||
response = await e2e_client.get(endpoint)
|
||||
assert response.status_code == 401, f"Expected 401 for {endpoint}"
|
||||
|
||||
async def test_expired_token_rejected_for_admin(self, e2e_client):
|
||||
"""Expired tokens are rejected for admin endpoints."""
|
||||
# Use a clearly invalid/malformed token
|
||||
fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={"Authorization": f"Bearer {fake_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -40,56 +40,154 @@ if SCHEMATHESIS_AVAILABLE:
|
||||
# Load schema from the FastAPI app using schemathesis.openapi (v4.x API)
|
||||
schema = openapi.from_asgi("/api/v1/openapi.json", app=app)
|
||||
|
||||
# Test root endpoint (simple, always works)
|
||||
# =========================================================================
|
||||
# Public Endpoints (No Auth Required)
|
||||
# =========================================================================
|
||||
|
||||
# Test root endpoint
|
||||
root_schema = schema.include(path="/")
|
||||
|
||||
@root_schema.parametrize()
|
||||
@settings(max_examples=5)
|
||||
def test_root_endpoint_schema(case):
|
||||
"""
|
||||
Root endpoint schema compliance.
|
||||
|
||||
Tests that the root endpoint returns responses matching its schema.
|
||||
"""
|
||||
"""Root endpoint schema compliance."""
|
||||
response = case.call()
|
||||
# Just verify we get a response and no 5xx errors
|
||||
assert response.status_code < 500, f"Server error: {response.text}"
|
||||
|
||||
# Test health endpoint
|
||||
health_schema = schema.include(path="/health")
|
||||
|
||||
@health_schema.parametrize()
|
||||
@settings(max_examples=3)
|
||||
def test_health_endpoint_schema(case):
|
||||
"""Health endpoint schema compliance."""
|
||||
response = case.call()
|
||||
# Health check may return 200 or 503 depending on DB
|
||||
assert response.status_code < 500 or response.status_code == 503
|
||||
|
||||
# Test auth registration endpoint
|
||||
# Note: This tests schema validation, not actual database operations
|
||||
auth_register_schema = schema.include(path="/api/v1/auth/register")
|
||||
|
||||
@auth_register_schema.parametrize()
|
||||
@settings(max_examples=10)
|
||||
def test_register_endpoint_validates_input(case):
|
||||
"""
|
||||
Registration endpoint input validation.
|
||||
|
||||
Schemathesis generates various inputs to test validation.
|
||||
The endpoint should never return 5xx errors for invalid input.
|
||||
"""
|
||||
"""Registration endpoint input validation."""
|
||||
response = case.call()
|
||||
# Registration returns 200/201 (success), 400/422 (validation), 409 (conflict)
|
||||
# Never a 5xx error for validation issues
|
||||
# 200/201 (success), 400/422 (validation), 409 (conflict)
|
||||
assert response.status_code < 500, f"Server error: {response.text}"
|
||||
|
||||
# Note: Login and refresh endpoints require database, so they're tested
|
||||
# in test_database_workflows.py instead of here. Schemathesis tests run
|
||||
# without the testcontainers database fixtures.
|
||||
|
||||
# =========================================================================
|
||||
# Protected Endpoints - Manual tests for auth requirements
|
||||
# (Schemathesis parametrize tests all methods, manual tests are clearer)
|
||||
# =========================================================================
|
||||
|
||||
class TestProtectedEndpointsRequireAuth:
|
||||
"""Test that protected endpoints return proper auth errors."""
|
||||
|
||||
def test_users_me_requires_auth(self):
|
||||
"""Users/me GET endpoint requires authentication."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/v1/users/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_sessions_me_requires_auth(self):
|
||||
"""Sessions/me GET endpoint requires authentication."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/v1/sessions/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_organizations_me_requires_auth(self):
|
||||
"""Organizations/me GET endpoint requires authentication."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/v1/organizations/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_admin_users_requires_auth(self):
|
||||
"""Admin users GET endpoint requires authentication."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/v1/admin/users")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_admin_stats_requires_auth(self):
|
||||
"""Admin stats GET endpoint requires authentication."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/v1/admin/stats")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_admin_organizations_requires_auth(self):
|
||||
"""Admin organizations GET endpoint requires authentication."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/v1/admin/organizations")
|
||||
assert response.status_code == 401
|
||||
|
||||
# =========================================================================
|
||||
# Schema Validation Tests
|
||||
# =========================================================================
|
||||
|
||||
class TestSchemaValidation:
|
||||
"""Manual validation tests for schema structure."""
|
||||
|
||||
def test_schema_loaded_successfully(self):
|
||||
"""Verify schema was loaded from the app."""
|
||||
# Count operations to verify schema loaded
|
||||
ops = list(schema.get_all_operations())
|
||||
assert len(ops) > 0, "No operations found in schema"
|
||||
|
||||
def test_multiple_endpoints_documented(self):
|
||||
"""Verify multiple endpoints are documented in schema."""
|
||||
ops = list(schema.get_all_operations())
|
||||
# Should have at least 10 operations in a real API
|
||||
assert len(ops) >= 10, f"Only {len(ops)} operations found"
|
||||
|
||||
def test_schema_has_auth_operations(self):
|
||||
"""Verify auth-related operations exist."""
|
||||
# Filter for auth endpoints
|
||||
auth_ops = list(schema.include(path_regex=r".*auth.*").get_all_operations())
|
||||
assert len(auth_ops) > 0, "No auth operations found"
|
||||
|
||||
def test_schema_has_user_operations(self):
|
||||
"""Verify user-related operations exist."""
|
||||
user_ops = list(
|
||||
schema.include(path_regex=r".*users.*").get_all_operations()
|
||||
)
|
||||
assert len(user_ops) > 0, "No user operations found"
|
||||
|
||||
def test_schema_has_organization_operations(self):
|
||||
"""Verify organization-related operations exist."""
|
||||
org_ops = list(
|
||||
schema.include(path_regex=r".*organizations.*").get_all_operations()
|
||||
)
|
||||
assert len(org_ops) > 0, "No organization operations found"
|
||||
|
||||
def test_schema_has_admin_operations(self):
|
||||
"""Verify admin-related operations exist."""
|
||||
admin_ops = list(
|
||||
schema.include(path_regex=r".*admin.*").get_all_operations()
|
||||
)
|
||||
assert len(admin_ops) > 0, "No admin operations found"
|
||||
|
||||
def test_schema_has_session_operations(self):
|
||||
"""Verify session-related operations exist."""
|
||||
session_ops = list(
|
||||
schema.include(path_regex=r".*sessions.*").get_all_operations()
|
||||
)
|
||||
assert len(session_ops) > 0, "No session operations found"
|
||||
|
||||
def test_total_endpoint_count(self):
|
||||
"""Verify expected number of endpoints are documented."""
|
||||
ops = list(schema.get_all_operations())
|
||||
# We expect at least 40+ endpoints in this comprehensive API
|
||||
assert len(ops) >= 40, f"Only {len(ops)} operations found, expected 40+"
|
||||
|
||||
@@ -188,3 +188,134 @@ class TestHealthEndpoint:
|
||||
assert response.status_code in [200, 503]
|
||||
data = response.json()
|
||||
assert "status" in data
|
||||
|
||||
|
||||
class TestLogoutWorkflows:
|
||||
"""Test logout workflows."""
|
||||
|
||||
async def test_logout_invalidates_session(self, e2e_client):
|
||||
"""Test that logout invalidates the session."""
|
||||
email = f"e2e-logout-{uuid4().hex[:8]}@example.com"
|
||||
password = "SecurePassword123!"
|
||||
|
||||
# Register and login
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Logout",
|
||||
"last_name": "Test",
|
||||
},
|
||||
)
|
||||
|
||||
login_resp = await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
tokens = login_resp.json()
|
||||
|
||||
# Logout requires both access token (auth) and refresh token (body)
|
||||
logout_resp = await e2e_client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
assert logout_resp.status_code == 200
|
||||
|
||||
async def test_invalid_refresh_token_rejected(self, e2e_client):
|
||||
"""Test that invalid refresh tokens are rejected."""
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid_refresh_token"},
|
||||
)
|
||||
assert response.status_code in [401, 422]
|
||||
|
||||
|
||||
class TestValidationWorkflows:
|
||||
"""Test input validation workflows."""
|
||||
|
||||
async def test_register_invalid_email(self, e2e_client):
|
||||
"""Test that invalid email format is rejected."""
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": "not_an_email",
|
||||
"password": "ValidPassword123!",
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_register_weak_password(self, e2e_client):
|
||||
"""Test that weak passwords are rejected."""
|
||||
email = f"e2e-weak-{uuid4().hex[:8]}@example.com"
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "weak", # Too weak
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_login_missing_fields(self, e2e_client):
|
||||
"""Test that login requires all fields."""
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "test@example.com"}, # Missing password
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestRootEndpoint:
|
||||
"""Test root endpoint."""
|
||||
|
||||
async def test_root_responds(self, e2e_client):
|
||||
"""Root endpoint should respond with HTML."""
|
||||
response = await e2e_client.get("/")
|
||||
assert response.status_code == 200
|
||||
# Root returns HTML
|
||||
assert "html" in response.text.lower() or "Welcome" in response.text
|
||||
|
||||
async def test_openapi_available(self, e2e_client):
|
||||
"""OpenAPI schema should be available."""
|
||||
response = await e2e_client.get("/api/v1/openapi.json")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "openapi" in data
|
||||
assert "paths" in data
|
||||
|
||||
|
||||
class TestAuthTokenWorkflows:
|
||||
"""Test authentication token workflows."""
|
||||
|
||||
async def test_access_token_expires(self, e2e_client):
|
||||
"""Test using expired access token."""
|
||||
# Use a fake/expired token
|
||||
fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiZXhwIjoxNjAwMDAwMDAwfQ.invalid"
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {fake_token}"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_malformed_token_rejected(self, e2e_client):
|
||||
"""Test that malformed tokens are rejected."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": "Bearer not-a-valid-token"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_missing_bearer_prefix(self, e2e_client):
|
||||
"""Test that tokens without Bearer prefix are rejected."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": "some-token"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
353
backend/tests/e2e/test_organization_workflows.py
Normal file
353
backend/tests/e2e/test_organization_workflows.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Organization E2E workflow tests with real PostgreSQL.
|
||||
|
||||
These tests validate complete organization workflows including:
|
||||
- Creating organizations (via admin)
|
||||
- Viewing user's organizations
|
||||
- Organization membership management
|
||||
- Organization updates
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.e2e,
|
||||
pytest.mark.postgres,
|
||||
pytest.mark.asyncio,
|
||||
]
|
||||
|
||||
|
||||
async def register_and_login(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
|
||||
"""Helper to register a user and get tokens."""
|
||||
# Register
|
||||
await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Login
|
||||
login_resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
tokens = login_resp.json()
|
||||
return tokens
|
||||
|
||||
|
||||
async def create_superuser_and_login(client, db_session):
|
||||
"""Helper to create a superuser directly in DB and login."""
|
||||
from app.crud.user import user as user_crud
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
email = f"admin-{uuid4().hex[:8]}@example.com"
|
||||
password = "AdminPassword123!"
|
||||
|
||||
# Create superuser directly
|
||||
user_in = UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True,
|
||||
)
|
||||
await user_crud.create(db_session, obj_in=user_in)
|
||||
|
||||
# Login
|
||||
login_resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
return login_resp.json(), email
|
||||
|
||||
|
||||
class TestOrganizationWorkflows:
|
||||
"""Test organization management workflows."""
|
||||
|
||||
async def test_user_has_no_organizations_initially(self, e2e_client):
|
||||
"""New users should have no organizations."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 0
|
||||
|
||||
async def test_get_organizations_requires_auth(self, e2e_client):
|
||||
"""Organizations endpoint requires authentication."""
|
||||
response = await e2e_client.get("/api/v1/organizations/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_get_nonexistent_organization(self, e2e_client):
|
||||
"""Getting a non-member organization returns 403."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
fake_org_id = str(uuid4())
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/organizations/{fake_org_id}",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
# Should be 403 (not a member) or 404 (not found)
|
||||
assert response.status_code in [403, 404]
|
||||
|
||||
|
||||
class TestOrganizationMembershipWorkflows:
|
||||
"""Test organization membership workflows."""
|
||||
|
||||
async def test_non_member_cannot_view_org_details(self, e2e_client):
|
||||
"""Users cannot view organizations they're not members of."""
|
||||
# Create two users
|
||||
user1_email = f"e2e-user1-{uuid4().hex[:8]}@example.com"
|
||||
user2_email = f"e2e-user2-{uuid4().hex[:8]}@example.com"
|
||||
|
||||
await register_and_login(e2e_client, user1_email)
|
||||
user2_tokens = await register_and_login(e2e_client, user2_email)
|
||||
|
||||
# User2 tries to access a random org ID
|
||||
fake_org_id = str(uuid4())
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/organizations/{fake_org_id}",
|
||||
headers={"Authorization": f"Bearer {user2_tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code in [403, 404]
|
||||
|
||||
async def test_non_member_cannot_view_org_members(self, e2e_client):
|
||||
"""Users cannot view members of organizations they don't belong to."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
fake_org_id = str(uuid4())
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/organizations/{fake_org_id}/members",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code in [403, 404]
|
||||
|
||||
async def test_non_admin_cannot_update_organization(self, e2e_client):
|
||||
"""Regular users cannot update organizations (need admin role)."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
fake_org_id = str(uuid4())
|
||||
response = await e2e_client.put(
|
||||
f"/api/v1/organizations/{fake_org_id}",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"name": "Updated Name"},
|
||||
)
|
||||
|
||||
assert response.status_code in [403, 404]
|
||||
|
||||
|
||||
class TestOrganizationWithMembers:
|
||||
"""Test organization workflows using e2e_org_with_members fixture."""
|
||||
|
||||
async def test_owner_can_view_organization(self, e2e_client, e2e_org_with_members):
|
||||
"""Organization owner can view organization details."""
|
||||
org = e2e_org_with_members
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/organizations/{org['org_id']}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == org["org_id"]
|
||||
assert data["name"] == org["org_name"]
|
||||
|
||||
async def test_member_can_view_organization(self, e2e_client, e2e_org_with_members):
|
||||
"""Organization member can view organization details."""
|
||||
org = e2e_org_with_members
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/organizations/{org['org_id']}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {org['member']['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == org["org_id"]
|
||||
|
||||
async def test_owner_can_list_members(self, e2e_client, e2e_org_with_members):
|
||||
"""Organization owner can list members."""
|
||||
org = e2e_org_with_members
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/organizations/{org['org_id']}/members",
|
||||
headers={
|
||||
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should have owner + member = at least 2 members
|
||||
assert len(data) >= 2
|
||||
|
||||
async def test_member_can_list_members(self, e2e_client, e2e_org_with_members):
|
||||
"""Organization member can list members."""
|
||||
org = e2e_org_with_members
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/organizations/{org['org_id']}/members",
|
||||
headers={
|
||||
"Authorization": f"Bearer {org['member']['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_owner_appears_in_my_organizations(
|
||||
self, e2e_client, e2e_org_with_members
|
||||
):
|
||||
"""Owner sees organization in their organizations list."""
|
||||
org = e2e_org_with_members
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={
|
||||
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
org_ids = [o["id"] for o in data]
|
||||
assert org["org_id"] in org_ids
|
||||
|
||||
async def test_member_appears_in_my_organizations(
|
||||
self, e2e_client, e2e_org_with_members
|
||||
):
|
||||
"""Member sees organization in their organizations list."""
|
||||
org = e2e_org_with_members
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={
|
||||
"Authorization": f"Bearer {org['member']['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
org_ids = [o["id"] for o in data]
|
||||
assert org["org_id"] in org_ids
|
||||
|
||||
async def test_owner_can_update_organization(
|
||||
self, e2e_client, e2e_org_with_members
|
||||
):
|
||||
"""Organization owner can update organization details."""
|
||||
org = e2e_org_with_members
|
||||
new_description = f"Updated at {uuid4().hex[:8]}"
|
||||
|
||||
response = await e2e_client.put(
|
||||
f"/api/v1/organizations/{org['org_id']}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
|
||||
},
|
||||
json={"description": new_description},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["description"] == new_description
|
||||
|
||||
async def test_member_cannot_update_organization(
|
||||
self, e2e_client, e2e_org_with_members
|
||||
):
|
||||
"""Regular member cannot update organization details."""
|
||||
org = e2e_org_with_members
|
||||
|
||||
response = await e2e_client.put(
|
||||
f"/api/v1/organizations/{org['org_id']}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {org['member']['tokens']['access_token']}"
|
||||
},
|
||||
json={"description": "Should fail"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_non_member_cannot_view_organization(
|
||||
self, e2e_client, e2e_org_with_members
|
||||
):
|
||||
"""Non-members cannot view organization details."""
|
||||
org = e2e_org_with_members
|
||||
|
||||
# Create a new user who is not a member
|
||||
non_member_email = f"nonmember-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, non_member_email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/organizations/{org['org_id']}",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_get_organization_by_slug(self, e2e_client, e2e_org_with_members):
|
||||
"""Organization can be retrieved by slug."""
|
||||
org = e2e_org_with_members
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/organizations/slug/{org['org_slug']}",
|
||||
headers={
|
||||
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
# May be 200 or 404/403 depending on implementation
|
||||
assert response.status_code in [200, 403, 404]
|
||||
|
||||
|
||||
class TestOrganizationAdminOperations:
|
||||
"""Test organization admin operations."""
|
||||
|
||||
async def test_admin_list_org_members_with_pagination(
|
||||
self, e2e_client, e2e_superuser, e2e_org_with_members
|
||||
):
|
||||
"""Admin can list org members with pagination."""
|
||||
org = e2e_org_with_members
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/admin/organizations/{org['org_id']}/members",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
params={"page": 1, "limit": 10},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "pagination" in data
|
||||
|
||||
async def test_admin_list_org_members_filter_active(
|
||||
self, e2e_client, e2e_superuser, e2e_org_with_members
|
||||
):
|
||||
"""Admin can filter org members by active status."""
|
||||
org = e2e_org_with_members
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/admin/organizations/{org['org_id']}/members",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
params={"is_active": True},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
331
backend/tests/e2e/test_session_workflows.py
Normal file
331
backend/tests/e2e/test_session_workflows.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Session management E2E workflow tests with real PostgreSQL.
|
||||
|
||||
These tests validate complete session management workflows including:
|
||||
- Listing active sessions
|
||||
- Session revocation
|
||||
- Session cleanup
|
||||
- Multi-device session handling
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.e2e,
|
||||
pytest.mark.postgres,
|
||||
pytest.mark.asyncio,
|
||||
]
|
||||
|
||||
|
||||
async def register_and_login(
|
||||
client,
|
||||
email: str,
|
||||
password: str = "SecurePassword123!", # noqa: S107
|
||||
user_agent: str | None = None,
|
||||
):
|
||||
"""Helper to register a user and get tokens."""
|
||||
await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
headers = {}
|
||||
if user_agent:
|
||||
headers["User-Agent"] = user_agent
|
||||
|
||||
login_resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
headers=headers,
|
||||
)
|
||||
return login_resp.json()
|
||||
|
||||
|
||||
class TestSessionListingWorkflows:
|
||||
"""Test session listing workflows."""
|
||||
|
||||
async def test_list_sessions_after_login(self, e2e_client):
|
||||
"""Users can list their active sessions after login."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "sessions" in data
|
||||
assert "total" in data
|
||||
assert data["total"] >= 1
|
||||
assert len(data["sessions"]) >= 1
|
||||
|
||||
async def test_session_contains_expected_fields(self, e2e_client):
|
||||
"""Session response contains expected fields."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
session = data["sessions"][0]
|
||||
|
||||
# Check required fields
|
||||
assert "id" in session
|
||||
assert "created_at" in session
|
||||
assert "last_used_at" in session
|
||||
assert "is_current" in session
|
||||
|
||||
async def test_list_sessions_requires_auth(self, e2e_client):
|
||||
"""Sessions endpoint requires authentication."""
|
||||
response = await e2e_client.get("/api/v1/sessions/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_multiple_logins_create_multiple_sessions(self, e2e_client):
|
||||
"""Multiple logins create multiple sessions."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
password = "SecurePassword123!"
|
||||
|
||||
# Register
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Login multiple times with different user agents
|
||||
tokens1 = (
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"},
|
||||
)
|
||||
).json()
|
||||
|
||||
# Second login to create another session
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
headers={"User-Agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0)"},
|
||||
)
|
||||
|
||||
# Check sessions using first token
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["total"] >= 2
|
||||
|
||||
|
||||
class TestSessionRevocationWorkflows:
|
||||
"""Test session revocation workflows."""
|
||||
|
||||
async def test_revoke_own_session(self, e2e_client):
|
||||
"""Users can revoke their own sessions."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
password = "SecurePassword123!"
|
||||
|
||||
# Register
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Create two sessions
|
||||
tokens1 = (
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
).json()
|
||||
|
||||
# Second login to create another session
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
|
||||
# Get sessions
|
||||
sessions_resp = await e2e_client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
|
||||
)
|
||||
sessions = sessions_resp.json()["sessions"]
|
||||
initial_count = len(sessions)
|
||||
|
||||
# Revoke one session (not the current one)
|
||||
session_to_revoke = sessions[-1]["id"]
|
||||
revoke_resp = await e2e_client.delete(
|
||||
f"/api/v1/sessions/{session_to_revoke}",
|
||||
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
|
||||
)
|
||||
|
||||
assert revoke_resp.status_code == 200
|
||||
assert revoke_resp.json()["success"] is True
|
||||
|
||||
# Verify session count decreased
|
||||
updated_sessions_resp = await e2e_client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
|
||||
)
|
||||
updated_count = updated_sessions_resp.json()["total"]
|
||||
assert updated_count == initial_count - 1
|
||||
|
||||
async def test_cannot_revoke_nonexistent_session(self, e2e_client):
|
||||
"""Cannot revoke a session that doesn't exist."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
fake_session_id = str(uuid4())
|
||||
response = await e2e_client.delete(
|
||||
f"/api/v1/sessions/{fake_session_id}",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
async def test_cannot_revoke_other_user_session(self, e2e_client):
|
||||
"""Users cannot revoke other users' sessions."""
|
||||
user1_email = f"e2e-user1-{uuid4().hex[:8]}@example.com"
|
||||
user2_email = f"e2e-user2-{uuid4().hex[:8]}@example.com"
|
||||
|
||||
tokens1 = await register_and_login(e2e_client, user1_email)
|
||||
tokens2 = await register_and_login(e2e_client, user2_email)
|
||||
|
||||
# Get user1's session ID
|
||||
sessions_resp = await e2e_client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
|
||||
)
|
||||
user1_session_id = sessions_resp.json()["sessions"][0]["id"]
|
||||
|
||||
# User2 tries to revoke user1's session
|
||||
response = await e2e_client.delete(
|
||||
f"/api/v1/sessions/{user1_session_id}",
|
||||
headers={"Authorization": f"Bearer {tokens2['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestSessionCleanupWorkflows:
|
||||
"""Test session cleanup workflows."""
|
||||
|
||||
async def test_cleanup_expired_sessions(self, e2e_client):
|
||||
"""Users can cleanup their expired sessions."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
response = await e2e_client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "Cleaned up" in data["message"]
|
||||
|
||||
async def test_cleanup_requires_auth(self, e2e_client):
|
||||
"""Session cleanup requires authentication."""
|
||||
response = await e2e_client.delete("/api/v1/sessions/me/expired")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
class TestLogoutWorkflows:
|
||||
"""Test logout workflows."""
|
||||
|
||||
async def test_logout_invalidates_session(self, e2e_client):
|
||||
"""Logout should invalidate the session."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
# Logout
|
||||
logout_resp = await e2e_client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
|
||||
assert logout_resp.status_code == 200
|
||||
|
||||
# Refresh token should no longer work
|
||||
refresh_resp = await e2e_client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
|
||||
# May be 401 or 400 depending on implementation
|
||||
assert refresh_resp.status_code in [400, 401]
|
||||
|
||||
async def test_logout_all_revokes_all_sessions(self, e2e_client):
|
||||
"""Logout all should revoke all sessions."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
password = "SecurePassword123!"
|
||||
|
||||
# Register
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Create multiple sessions
|
||||
tokens1 = (
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
).json()
|
||||
|
||||
tokens2 = (
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
).json()
|
||||
|
||||
# Logout all
|
||||
logout_resp = await e2e_client.post(
|
||||
"/api/v1/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
|
||||
)
|
||||
|
||||
assert logout_resp.status_code == 200
|
||||
|
||||
# Second token's refresh should no longer work
|
||||
refresh_resp = await e2e_client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": tokens2["refresh_token"]},
|
||||
)
|
||||
|
||||
assert refresh_resp.status_code in [400, 401]
|
||||
351
backend/tests/e2e/test_user_workflows.py
Normal file
351
backend/tests/e2e/test_user_workflows.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
User management E2E workflow tests with real PostgreSQL.
|
||||
|
||||
These tests validate complete user management workflows including:
|
||||
- Profile viewing and updates
|
||||
- Password changes
|
||||
- User settings management
|
||||
|
||||
Usage:
|
||||
make test-e2e # Run all E2E tests
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.e2e,
|
||||
pytest.mark.postgres,
|
||||
pytest.mark.asyncio,
|
||||
]
|
||||
|
||||
|
||||
async def register_and_login(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
|
||||
"""Helper to register a user and get tokens."""
|
||||
await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
login_resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
return login_resp.json()
|
||||
|
||||
|
||||
class TestUserProfileWorkflows:
|
||||
"""Test user profile management workflows."""
|
||||
|
||||
async def test_get_own_profile(self, e2e_client):
|
||||
"""Users can view their own profile."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == email
|
||||
assert data["first_name"] == "Test"
|
||||
assert data["last_name"] == "User"
|
||||
assert "id" in data
|
||||
assert "is_active" in data
|
||||
|
||||
async def test_update_own_profile(self, e2e_client):
|
||||
"""Users can update their own profile."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
response = await e2e_client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"first_name": "Updated",
|
||||
"last_name": "Name",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["first_name"] == "Updated"
|
||||
assert data["last_name"] == "Name"
|
||||
|
||||
# Verify changes persisted
|
||||
verify_resp = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
assert verify_resp.json()["first_name"] == "Updated"
|
||||
|
||||
async def test_profile_requires_auth(self, e2e_client):
|
||||
"""Profile endpoints require authentication."""
|
||||
response = await e2e_client.get("/api/v1/users/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_get_user_by_id_own_profile(self, e2e_client):
|
||||
"""Users can get their own profile by ID."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
# Get user ID from /me endpoint
|
||||
me_resp = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
user_id = me_resp.json()["id"]
|
||||
|
||||
# Get by ID
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/users/{user_id}",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == user_id
|
||||
|
||||
async def test_cannot_get_other_user_profile(self, e2e_client):
|
||||
"""Regular users cannot view other users' profiles."""
|
||||
# Create two users
|
||||
user1_email = f"e2e-user1-{uuid4().hex[:8]}@example.com"
|
||||
user2_email = f"e2e-user2-{uuid4().hex[:8]}@example.com"
|
||||
|
||||
tokens1 = await register_and_login(e2e_client, user1_email)
|
||||
tokens2 = await register_and_login(e2e_client, user2_email)
|
||||
|
||||
# Get user1's ID
|
||||
me_resp = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
|
||||
)
|
||||
user1_id = me_resp.json()["id"]
|
||||
|
||||
# User2 tries to access user1's profile
|
||||
response = await e2e_client.get(
|
||||
f"/api/v1/users/{user1_id}",
|
||||
headers={"Authorization": f"Bearer {tokens2['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestPasswordChangeWorkflows:
|
||||
"""Test password change workflows."""
|
||||
|
||||
async def test_change_password_success(self, e2e_client):
|
||||
"""Users can change their password with correct current password."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
old_password = "OldPassword123!"
|
||||
new_password = "NewPassword456!"
|
||||
|
||||
tokens = await register_and_login(e2e_client, email, old_password)
|
||||
|
||||
response = await e2e_client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"current_password": old_password,
|
||||
"new_password": new_password,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
# Verify new password works
|
||||
login_resp = await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": new_password},
|
||||
)
|
||||
assert login_resp.status_code == 200
|
||||
|
||||
async def test_change_password_wrong_current(self, e2e_client):
|
||||
"""Password change fails with wrong current password."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
response = await e2e_client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"current_password": "WrongPassword123!",
|
||||
"new_password": "NewPassword456!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_change_password_weak_new_password(self, e2e_client):
|
||||
"""Password change fails with weak new password."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
password = "SecurePassword123!"
|
||||
tokens = await register_and_login(e2e_client, email, password)
|
||||
|
||||
response = await e2e_client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"current_password": password,
|
||||
"new_password": "weak", # Too weak
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
async def test_old_password_invalid_after_change(self, e2e_client):
|
||||
"""Old password no longer works after password change."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
old_password = "OldPassword123!"
|
||||
new_password = "NewPassword456!"
|
||||
|
||||
tokens = await register_and_login(e2e_client, email, old_password)
|
||||
|
||||
# Change password
|
||||
await e2e_client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={
|
||||
"current_password": old_password,
|
||||
"new_password": new_password,
|
||||
},
|
||||
)
|
||||
|
||||
# Old password should fail
|
||||
login_resp = await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": old_password},
|
||||
)
|
||||
assert login_resp.status_code == 401
|
||||
|
||||
|
||||
class TestUserUpdateWorkflows:
|
||||
"""Test user update edge cases."""
|
||||
|
||||
async def test_cannot_elevate_own_privileges(self, e2e_client):
|
||||
"""Users cannot make themselves superusers."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
# Try to make self superuser - should be silently ignored or rejected
|
||||
response = await e2e_client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"is_superuser": True},
|
||||
)
|
||||
|
||||
# The request may succeed but is_superuser should not change
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
assert data.get("is_superuser") is False
|
||||
else:
|
||||
# Or it may be rejected outright
|
||||
assert response.status_code in [400, 403, 422]
|
||||
|
||||
async def test_cannot_update_other_user_profile(self, e2e_client):
|
||||
"""Regular users cannot update other users' profiles."""
|
||||
user1_email = f"e2e-user1-{uuid4().hex[:8]}@example.com"
|
||||
user2_email = f"e2e-user2-{uuid4().hex[:8]}@example.com"
|
||||
|
||||
tokens1 = await register_and_login(e2e_client, user1_email)
|
||||
tokens2 = await register_and_login(e2e_client, user2_email)
|
||||
|
||||
# Get user1's ID
|
||||
me_resp = await e2e_client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
|
||||
)
|
||||
user1_id = me_resp.json()["id"]
|
||||
|
||||
# User2 tries to update user1
|
||||
response = await e2e_client.patch(
|
||||
f"/api/v1/users/{user1_id}",
|
||||
headers={"Authorization": f"Bearer {tokens2['access_token']}"},
|
||||
json={"first_name": "Hacked"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestAdminUserListWorkflows:
|
||||
"""Test admin user list workflows via /users endpoint."""
|
||||
|
||||
async def test_superuser_can_list_all_users(self, e2e_client, e2e_superuser):
|
||||
"""Superuser can list all users via /users endpoint."""
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "pagination" in data
|
||||
|
||||
async def test_regular_user_cannot_list_users(self, e2e_client):
|
||||
"""Regular users cannot list all users."""
|
||||
email = f"e2e-{uuid4().hex[:8]}@example.com"
|
||||
tokens = await register_and_login(e2e_client, email)
|
||||
|
||||
response = await e2e_client.get(
|
||||
"/api/v1/users",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestDeactivatedUserWorkflows:
|
||||
"""Test workflows involving deactivated users."""
|
||||
|
||||
async def test_deactivated_user_cannot_login(self, e2e_client, e2e_superuser):
|
||||
"""Deactivated users cannot login."""
|
||||
# Create user
|
||||
email = f"deactivate-login-{uuid4().hex[:8]}@example.com"
|
||||
password = "DeactivatePass123!"
|
||||
await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": password,
|
||||
"first_name": "Deactivate",
|
||||
"last_name": "Login",
|
||||
},
|
||||
)
|
||||
|
||||
# Get user ID
|
||||
list_resp = await e2e_client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
users = list_resp.json()["data"]
|
||||
target_user = next(u for u in users if u["email"] == email)
|
||||
|
||||
# Deactivate user
|
||||
await e2e_client.post(
|
||||
f"/api/v1/admin/users/{target_user['id']}/deactivate",
|
||||
headers={
|
||||
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
|
||||
},
|
||||
)
|
||||
|
||||
# Try to login - should fail
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
|
||||
assert response.status_code in [401, 403]
|
||||
772
backend/tests/services/test_oauth_provider_service.py
Normal file
772
backend/tests/services/test_oauth_provider_service.py
Normal file
@@ -0,0 +1,772 @@
|
||||
# tests/services/test_oauth_provider_service.py
|
||||
"""
|
||||
Tests for OAuth Provider Service (Authorization Server mode for MCP).
|
||||
|
||||
Covers:
|
||||
- Authorization code creation and exchange
|
||||
- Token generation, refresh, and revocation
|
||||
- PKCE verification
|
||||
- Token introspection (RFC 7662)
|
||||
- Consent management
|
||||
- Error handling
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.user import User
|
||||
from app.services import oauth_provider_service as service
|
||||
from app.utils.test_utils import setup_async_test_db, teardown_async_test_db
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def db():
|
||||
"""Fixture provides testing engine and session for each test."""
|
||||
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
yield session
|
||||
await teardown_async_test_db(test_engine)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user(db):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid4(),
|
||||
email="testuser@example.com",
|
||||
password_hash="$2b$12$test",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def public_client(db):
|
||||
"""Create a test public OAuth client."""
|
||||
client = OAuthClient(
|
||||
id=uuid4(),
|
||||
client_id="test_public_client",
|
||||
client_name="Test Public Client",
|
||||
client_type="public",
|
||||
redirect_uris=["http://localhost:3000/callback"],
|
||||
allowed_scopes=["openid", "profile", "email", "read:users"],
|
||||
is_active=True,
|
||||
)
|
||||
db.add(client)
|
||||
await db.commit()
|
||||
await db.refresh(client)
|
||||
return client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def confidential_client(db):
|
||||
"""Create a test confidential OAuth client using bcrypt."""
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
secret = "test_client_secret"
|
||||
# Use bcrypt for new client secret hashing (security improvement)
|
||||
secret_hash = get_password_hash(secret)
|
||||
client = OAuthClient(
|
||||
id=uuid4(),
|
||||
client_id="test_confidential_client",
|
||||
client_name="Test Confidential Client",
|
||||
client_type="confidential",
|
||||
client_secret_hash=secret_hash,
|
||||
redirect_uris=["http://localhost:3000/callback"],
|
||||
allowed_scopes=["openid", "profile", "email"],
|
||||
is_active=True,
|
||||
)
|
||||
db.add(client)
|
||||
await db.commit()
|
||||
await db.refresh(client)
|
||||
return client, secret
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def confidential_client_legacy_hash(db):
|
||||
"""Create a test confidential OAuth client with legacy SHA-256 hash."""
|
||||
# This tests backward compatibility with old SHA-256 hashed secrets
|
||||
secret = "test_legacy_secret"
|
||||
secret_hash = hashlib.sha256(secret.encode()).hexdigest()
|
||||
client = OAuthClient(
|
||||
id=uuid4(),
|
||||
client_id="test_legacy_client",
|
||||
client_name="Test Legacy Client",
|
||||
client_type="confidential",
|
||||
client_secret_hash=secret_hash,
|
||||
redirect_uris=["http://localhost:3000/callback"],
|
||||
allowed_scopes=["openid", "profile"],
|
||||
is_active=True,
|
||||
)
|
||||
db.add(client)
|
||||
await db.commit()
|
||||
await db.refresh(client)
|
||||
return client, secret
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Tests for helper functions."""
|
||||
|
||||
def test_generate_code_length(self):
|
||||
"""Test authorization code generation has proper length."""
|
||||
code = service.generate_code()
|
||||
assert len(code) > 64 # Base64 encoding of 64 bytes
|
||||
|
||||
def test_generate_code_unique(self):
|
||||
"""Test authorization codes are unique."""
|
||||
codes = [service.generate_code() for _ in range(100)]
|
||||
assert len(set(codes)) == 100
|
||||
|
||||
def test_generate_token(self):
|
||||
"""Test token generation."""
|
||||
token = service.generate_token()
|
||||
assert len(token) > 32
|
||||
|
||||
def test_generate_jti(self):
|
||||
"""Test JTI generation."""
|
||||
jti = service.generate_jti()
|
||||
assert len(jti) > 20
|
||||
|
||||
def test_hash_token(self):
|
||||
"""Test token hashing."""
|
||||
token = "test_token"
|
||||
hashed = service.hash_token(token)
|
||||
assert len(hashed) == 64 # SHA-256 hex digest
|
||||
|
||||
def test_hash_token_deterministic(self):
|
||||
"""Test same token produces same hash."""
|
||||
token = "test_token"
|
||||
hash1 = service.hash_token(token)
|
||||
hash2 = service.hash_token(token)
|
||||
assert hash1 == hash2
|
||||
|
||||
def test_parse_scope(self):
|
||||
"""Test scope parsing."""
|
||||
assert service.parse_scope("openid profile email") == [
|
||||
"openid",
|
||||
"profile",
|
||||
"email",
|
||||
]
|
||||
assert service.parse_scope("") == []
|
||||
assert service.parse_scope(" openid profile ") == ["openid", "profile"]
|
||||
|
||||
def test_join_scope(self):
|
||||
"""Test scope joining."""
|
||||
# Result is sorted and deduplicated
|
||||
result = service.join_scope(["profile", "openid", "profile"])
|
||||
assert "openid" in result
|
||||
assert "profile" in result
|
||||
|
||||
|
||||
class TestPKCEVerification:
|
||||
"""Tests for PKCE verification."""
|
||||
|
||||
def test_verify_pkce_s256_valid(self):
|
||||
"""Test PKCE verification with S256 method."""
|
||||
# Generate code_verifier
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
|
||||
# Generate code_challenge using S256
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
|
||||
assert service.verify_pkce(code_verifier, code_challenge, "S256") is True
|
||||
|
||||
def test_verify_pkce_s256_invalid(self):
|
||||
"""Test PKCE verification fails with wrong verifier."""
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
wrong_verifier = secrets.token_urlsafe(64)
|
||||
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
|
||||
assert service.verify_pkce(wrong_verifier, code_challenge, "S256") is False
|
||||
|
||||
def test_verify_pkce_plain_rejected(self):
|
||||
"""Test PKCE verification rejects 'plain' method for security."""
|
||||
# SECURITY: 'plain' method provides no security benefit and must be rejected
|
||||
# per RFC 7636 Section 4.3 - only S256 is allowed
|
||||
code_verifier = "test_verifier"
|
||||
assert service.verify_pkce(code_verifier, code_verifier, "plain") is False
|
||||
|
||||
def test_verify_pkce_unknown_method(self):
|
||||
"""Test PKCE verification with unknown method returns False."""
|
||||
assert service.verify_pkce("verifier", "challenge", "unknown") is False
|
||||
|
||||
|
||||
class TestClientValidation:
|
||||
"""Tests for client validation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_success(self, db, public_client):
|
||||
"""Test getting a valid client."""
|
||||
client = await service.get_client(db, public_client.client_id)
|
||||
assert client is not None
|
||||
assert client.client_id == public_client.client_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_not_found(self, db):
|
||||
"""Test getting a non-existent client."""
|
||||
client = await service.get_client(db, "nonexistent")
|
||||
assert client is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_inactive(self, db, public_client):
|
||||
"""Test getting an inactive client returns None."""
|
||||
public_client.is_active = False
|
||||
await db.commit()
|
||||
|
||||
client = await service.get_client(db, public_client.client_id)
|
||||
assert client is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_client_public(self, db, public_client):
|
||||
"""Test validating a public client."""
|
||||
client = await service.validate_client(db, public_client.client_id)
|
||||
assert client.client_id == public_client.client_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_client_confidential_with_secret(
|
||||
self, db, confidential_client
|
||||
):
|
||||
"""Test validating a confidential client with correct secret."""
|
||||
client, secret = confidential_client
|
||||
validated = await service.validate_client(db, client.client_id, secret)
|
||||
assert validated.client_id == client.client_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_client_confidential_wrong_secret(
|
||||
self, db, confidential_client
|
||||
):
|
||||
"""Test validating a confidential client with wrong secret."""
|
||||
client, _ = confidential_client
|
||||
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
|
||||
await service.validate_client(db, client.client_id, "wrong_secret")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_client_confidential_no_secret(
|
||||
self, db, confidential_client
|
||||
):
|
||||
"""Test validating a confidential client without secret."""
|
||||
client, _ = confidential_client
|
||||
with pytest.raises(service.InvalidClientError, match="Client secret required"):
|
||||
await service.validate_client(db, client.client_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_client_legacy_sha256_hash(
|
||||
self, db, confidential_client_legacy_hash
|
||||
):
|
||||
"""Test validating a client with legacy SHA-256 hash (backward compatibility)."""
|
||||
client, secret = confidential_client_legacy_hash
|
||||
validated = await service.validate_client(db, client.client_id, secret)
|
||||
assert validated.client_id == client.client_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_client_legacy_sha256_wrong_secret(
|
||||
self, db, confidential_client_legacy_hash
|
||||
):
|
||||
"""Test legacy SHA-256 client rejects wrong secret."""
|
||||
client, _ = confidential_client_legacy_hash
|
||||
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
|
||||
await service.validate_client(db, client.client_id, "wrong_secret")
|
||||
|
||||
def test_validate_redirect_uri_success(self, public_client):
|
||||
"""Test validating a registered redirect URI."""
|
||||
# Should not raise
|
||||
service.validate_redirect_uri(public_client, "http://localhost:3000/callback")
|
||||
|
||||
def test_validate_redirect_uri_invalid(self, public_client):
|
||||
"""Test validating an unregistered redirect URI."""
|
||||
with pytest.raises(service.InvalidRequestError, match="Invalid redirect_uri"):
|
||||
service.validate_redirect_uri(public_client, "http://evil.com/callback")
|
||||
|
||||
def test_validate_redirect_uri_no_uris(self, public_client):
|
||||
"""Test validating when client has no URIs."""
|
||||
public_client.redirect_uris = []
|
||||
with pytest.raises(service.InvalidRequestError, match="no registered"):
|
||||
service.validate_redirect_uri(public_client, "http://localhost:3000")
|
||||
|
||||
|
||||
class TestScopeValidation:
|
||||
"""Tests for scope validation."""
|
||||
|
||||
def test_validate_scopes_all_valid(self, public_client):
|
||||
"""Test validating all valid scopes."""
|
||||
scopes = service.validate_scopes(public_client, ["openid", "profile"])
|
||||
assert "openid" in scopes
|
||||
assert "profile" in scopes
|
||||
|
||||
def test_validate_scopes_partial_valid(self, public_client):
|
||||
"""Test validating with some invalid scopes - filters them out."""
|
||||
scopes = service.validate_scopes(public_client, ["openid", "invalid_scope"])
|
||||
assert "openid" in scopes
|
||||
assert "invalid_scope" not in scopes
|
||||
|
||||
def test_validate_scopes_empty_uses_all_allowed(self, public_client):
|
||||
"""Test empty scope request uses all allowed scopes."""
|
||||
scopes = service.validate_scopes(public_client, [])
|
||||
assert set(scopes) == set(public_client.allowed_scopes)
|
||||
|
||||
def test_validate_scopes_none_valid(self, public_client):
|
||||
"""Test validating with no valid scopes raises error."""
|
||||
with pytest.raises(service.InvalidScopeError):
|
||||
service.validate_scopes(public_client, ["invalid1", "invalid2"])
|
||||
|
||||
|
||||
class TestAuthorizationCode:
|
||||
"""Tests for authorization code creation and exchange."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_authorization_code_public_with_pkce(
|
||||
self, db, public_client, test_user
|
||||
):
|
||||
"""Test creating authorization code for public client with PKCE."""
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
|
||||
code = await service.create_authorization_code(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
scope="openid profile",
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method="S256",
|
||||
)
|
||||
|
||||
assert code is not None
|
||||
assert len(code) > 64
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_authorization_code_public_without_pkce_fails(
|
||||
self, db, public_client, test_user
|
||||
):
|
||||
"""Test creating authorization code for public client without PKCE fails."""
|
||||
with pytest.raises(service.InvalidRequestError, match="PKCE"):
|
||||
await service.create_authorization_code(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
scope="openid",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_authorization_code_success(
|
||||
self, db, public_client, test_user
|
||||
):
|
||||
"""Test exchanging valid authorization code for tokens."""
|
||||
# Create PKCE challenge
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
|
||||
# Create auth code
|
||||
code = await service.create_authorization_code(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
scope="openid profile",
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method="S256",
|
||||
)
|
||||
|
||||
# Exchange code
|
||||
with patch("app.services.oauth_provider_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
|
||||
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
|
||||
mock_settings.ALGORITHM = "HS256"
|
||||
|
||||
result = await service.exchange_authorization_code(
|
||||
db=db,
|
||||
code=code,
|
||||
client_id=public_client.client_id,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
assert "access_token" in result
|
||||
assert "refresh_token" in result
|
||||
assert result["token_type"] == "Bearer"
|
||||
assert "expires_in" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_authorization_code_invalid_code(self, db, public_client):
|
||||
"""Test exchanging invalid code fails."""
|
||||
with pytest.raises(service.InvalidGrantError, match="Invalid authorization"):
|
||||
await service.exchange_authorization_code(
|
||||
db=db,
|
||||
code="invalid_code",
|
||||
client_id=public_client.client_id,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_authorization_code_wrong_redirect_uri(
|
||||
self, db, public_client, test_user
|
||||
):
|
||||
"""Test exchanging code with wrong redirect_uri fails."""
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
|
||||
code = await service.create_authorization_code(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
scope="openid",
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method="S256",
|
||||
)
|
||||
|
||||
with pytest.raises(service.InvalidGrantError, match="redirect_uri mismatch"):
|
||||
await service.exchange_authorization_code(
|
||||
db=db,
|
||||
code=code,
|
||||
client_id=public_client.client_id,
|
||||
redirect_uri="http://different.com/callback",
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_authorization_code_invalid_pkce(
|
||||
self, db, public_client, test_user
|
||||
):
|
||||
"""Test exchanging code with invalid PKCE verifier fails."""
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
|
||||
code = await service.create_authorization_code(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
scope="openid",
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method="S256",
|
||||
)
|
||||
|
||||
with pytest.raises(service.InvalidGrantError, match="Invalid code_verifier"):
|
||||
await service.exchange_authorization_code(
|
||||
db=db,
|
||||
code=code,
|
||||
client_id=public_client.client_id,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
code_verifier="wrong_verifier",
|
||||
)
|
||||
|
||||
|
||||
class TestTokenRefresh:
|
||||
"""Tests for token refresh."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_success(self, db, public_client, test_user):
|
||||
"""Test refreshing tokens successfully."""
|
||||
# Create initial tokens
|
||||
with patch("app.services.oauth_provider_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
|
||||
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
|
||||
mock_settings.ALGORITHM = "HS256"
|
||||
|
||||
result = await service.create_tokens(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
scope="openid profile",
|
||||
)
|
||||
|
||||
refresh_token = result["refresh_token"]
|
||||
|
||||
# Refresh the tokens
|
||||
new_result = await service.refresh_tokens(
|
||||
db=db,
|
||||
refresh_token=refresh_token,
|
||||
client_id=public_client.client_id,
|
||||
)
|
||||
|
||||
assert "access_token" in new_result
|
||||
assert "refresh_token" in new_result
|
||||
assert new_result["refresh_token"] != refresh_token # Token rotation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_invalid_token(self, db, public_client):
|
||||
"""Test refreshing with invalid token fails."""
|
||||
with pytest.raises(service.InvalidGrantError, match="Invalid refresh token"):
|
||||
await service.refresh_tokens(
|
||||
db=db,
|
||||
refresh_token="invalid_token",
|
||||
client_id=public_client.client_id,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_scope_reduction(self, db, public_client, test_user):
|
||||
"""Test refreshing with reduced scope."""
|
||||
with patch("app.services.oauth_provider_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
|
||||
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
|
||||
mock_settings.ALGORITHM = "HS256"
|
||||
|
||||
result = await service.create_tokens(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
scope="openid profile email",
|
||||
)
|
||||
|
||||
new_result = await service.refresh_tokens(
|
||||
db=db,
|
||||
refresh_token=result["refresh_token"],
|
||||
client_id=public_client.client_id,
|
||||
scope="openid", # Reduced scope
|
||||
)
|
||||
|
||||
assert "openid" in new_result["scope"]
|
||||
assert "profile" not in new_result["scope"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_scope_expansion_fails(
|
||||
self, db, public_client, test_user
|
||||
):
|
||||
"""Test refreshing with expanded scope fails."""
|
||||
with patch("app.services.oauth_provider_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
|
||||
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
|
||||
mock_settings.ALGORITHM = "HS256"
|
||||
|
||||
result = await service.create_tokens(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
scope="openid",
|
||||
)
|
||||
|
||||
with pytest.raises(service.InvalidScopeError, match="Cannot expand scope"):
|
||||
await service.refresh_tokens(
|
||||
db=db,
|
||||
refresh_token=result["refresh_token"],
|
||||
client_id=public_client.client_id,
|
||||
scope="openid profile", # Expanded scope
|
||||
)
|
||||
|
||||
|
||||
class TestTokenRevocation:
|
||||
"""Tests for token revocation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_refresh_token(self, db, public_client, test_user):
|
||||
"""Test revoking a refresh token."""
|
||||
with patch("app.services.oauth_provider_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
|
||||
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
|
||||
mock_settings.ALGORITHM = "HS256"
|
||||
|
||||
result = await service.create_tokens(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
scope="openid",
|
||||
)
|
||||
|
||||
# Revoke the token
|
||||
revoked = await service.revoke_token(
|
||||
db=db,
|
||||
token=result["refresh_token"],
|
||||
token_type_hint="refresh_token",
|
||||
)
|
||||
|
||||
assert revoked is True
|
||||
|
||||
# Try to use revoked token
|
||||
with pytest.raises(service.InvalidGrantError, match="revoked"):
|
||||
await service.refresh_tokens(
|
||||
db=db,
|
||||
refresh_token=result["refresh_token"],
|
||||
client_id=public_client.client_id,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_all_user_tokens(self, db, public_client, test_user):
|
||||
"""Test revoking all tokens for a user."""
|
||||
with patch("app.services.oauth_provider_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
|
||||
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
|
||||
mock_settings.ALGORITHM = "HS256"
|
||||
|
||||
# Create multiple tokens (we don't need to capture results)
|
||||
await service.create_tokens(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
scope="openid",
|
||||
)
|
||||
await service.create_tokens(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
scope="profile",
|
||||
)
|
||||
|
||||
# Revoke all
|
||||
count = await service.revoke_all_user_tokens(db, test_user.id)
|
||||
assert count == 2
|
||||
|
||||
|
||||
class TestTokenIntrospection:
|
||||
"""Tests for token introspection (RFC 7662)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_introspect_valid_access_token(self, db, public_client, test_user):
|
||||
"""Test introspecting a valid access token."""
|
||||
with patch("app.services.oauth_provider_service.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
|
||||
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
|
||||
mock_settings.ALGORITHM = "HS256"
|
||||
|
||||
result = await service.create_tokens(
|
||||
db=db,
|
||||
client=public_client,
|
||||
user=test_user,
|
||||
scope="openid profile",
|
||||
)
|
||||
|
||||
introspection = await service.introspect_token(
|
||||
db=db,
|
||||
token=result["access_token"],
|
||||
)
|
||||
|
||||
assert introspection["active"] is True
|
||||
assert introspection["client_id"] == public_client.client_id
|
||||
assert introspection["sub"] == str(test_user.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_introspect_invalid_token(self, db):
|
||||
"""Test introspecting an invalid token."""
|
||||
introspection = await service.introspect_token(
|
||||
db=db,
|
||||
token="invalid_token",
|
||||
)
|
||||
assert introspection["active"] is False
|
||||
|
||||
|
||||
class TestConsentManagement:
|
||||
"""Tests for consent management."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grant_consent(self, db, public_client, test_user):
|
||||
"""Test granting consent."""
|
||||
consent = await service.grant_consent(
|
||||
db=db,
|
||||
user_id=test_user.id,
|
||||
client_id=public_client.client_id,
|
||||
scopes=["openid", "profile"],
|
||||
)
|
||||
|
||||
assert consent is not None
|
||||
assert "openid" in consent.granted_scopes
|
||||
assert "profile" in consent.granted_scopes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_consent_granted(self, db, public_client, test_user):
|
||||
"""Test checking granted consent."""
|
||||
await service.grant_consent(
|
||||
db=db,
|
||||
user_id=test_user.id,
|
||||
client_id=public_client.client_id,
|
||||
scopes=["openid", "profile"],
|
||||
)
|
||||
|
||||
has_consent = await service.check_consent(
|
||||
db=db,
|
||||
user_id=test_user.id,
|
||||
client_id=public_client.client_id,
|
||||
requested_scopes=["openid"],
|
||||
)
|
||||
assert has_consent is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_consent_not_granted(self, db, public_client, test_user):
|
||||
"""Test checking consent that hasn't been granted."""
|
||||
has_consent = await service.check_consent(
|
||||
db=db,
|
||||
user_id=test_user.id,
|
||||
client_id=public_client.client_id,
|
||||
requested_scopes=["openid"],
|
||||
)
|
||||
assert has_consent is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_consent(self, db, public_client, test_user):
|
||||
"""Test revoking consent."""
|
||||
await service.grant_consent(
|
||||
db=db,
|
||||
user_id=test_user.id,
|
||||
client_id=public_client.client_id,
|
||||
scopes=["openid"],
|
||||
)
|
||||
|
||||
revoked = await service.revoke_consent(
|
||||
db=db,
|
||||
user_id=test_user.id,
|
||||
client_id=public_client.client_id,
|
||||
)
|
||||
assert revoked is True
|
||||
|
||||
# Check consent is gone
|
||||
has_consent = await service.check_consent(
|
||||
db=db,
|
||||
user_id=test_user.id,
|
||||
client_id=public_client.client_id,
|
||||
requested_scopes=["openid"],
|
||||
)
|
||||
assert has_consent is False
|
||||
|
||||
|
||||
class TestOAuthErrors:
|
||||
"""Tests for OAuth error classes."""
|
||||
|
||||
def test_invalid_client_error(self):
|
||||
"""Test InvalidClientError."""
|
||||
error = service.InvalidClientError("Test description")
|
||||
assert error.error == "invalid_client"
|
||||
assert error.error_description == "Test description"
|
||||
|
||||
def test_invalid_grant_error(self):
|
||||
"""Test InvalidGrantError."""
|
||||
error = service.InvalidGrantError("Test description")
|
||||
assert error.error == "invalid_grant"
|
||||
assert error.error_description == "Test description"
|
||||
|
||||
def test_invalid_request_error(self):
|
||||
"""Test InvalidRequestError."""
|
||||
error = service.InvalidRequestError("Test description")
|
||||
assert error.error == "invalid_request"
|
||||
assert error.error_description == "Test description"
|
||||
|
||||
def test_invalid_scope_error(self):
|
||||
"""Test InvalidScopeError."""
|
||||
error = service.InvalidScopeError("Test description")
|
||||
assert error.error == "invalid_scope"
|
||||
assert error.error_description == "Test description"
|
||||
|
||||
def test_access_denied_error(self):
|
||||
"""Test AccessDeniedError."""
|
||||
error = service.AccessDeniedError("Test description")
|
||||
assert error.error == "access_denied"
|
||||
assert error.error_description == "Test description"
|
||||
@@ -451,6 +451,7 @@ class TestHandleCallbackComplete:
|
||||
state="valid_state_login",
|
||||
provider="google",
|
||||
code_verifier="test_verifier",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -533,6 +534,7 @@ class TestHandleCallbackComplete:
|
||||
state_data = OAuthStateCreate(
|
||||
state="valid_state_inactive",
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -583,6 +585,7 @@ class TestHandleCallbackComplete:
|
||||
state="valid_state_linking",
|
||||
provider="github",
|
||||
user_id=async_test_user.id, # User is logged in
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -648,6 +651,7 @@ class TestHandleCallbackComplete:
|
||||
state="valid_state_bad_user",
|
||||
provider="google",
|
||||
user_id=uuid4(), # Non-existent user
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -707,6 +711,7 @@ class TestHandleCallbackComplete:
|
||||
state="valid_state_already_linked",
|
||||
provider="google",
|
||||
user_id=async_test_user.id,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -769,6 +774,7 @@ class TestHandleCallbackComplete:
|
||||
state_data = OAuthStateCreate(
|
||||
state="valid_state_autolink",
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -832,6 +838,7 @@ class TestHandleCallbackComplete:
|
||||
state_data = OAuthStateCreate(
|
||||
state="valid_state_new_user",
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -904,6 +911,7 @@ class TestHandleCallbackComplete:
|
||||
state_data = OAuthStateCreate(
|
||||
state="valid_state_no_email",
|
||||
provider="github",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -961,6 +969,7 @@ class TestHandleCallbackComplete:
|
||||
state_data = OAuthStateCreate(
|
||||
state="valid_state_token_fail",
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -1004,6 +1013,7 @@ class TestHandleCallbackComplete:
|
||||
state_data = OAuthStateCreate(
|
||||
state="valid_state_userinfo_fail",
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -1047,6 +1057,7 @@ class TestHandleCallbackComplete:
|
||||
state_data = OAuthStateCreate(
|
||||
state="valid_state_no_token",
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
@@ -1090,6 +1101,7 @@ class TestHandleCallbackComplete:
|
||||
state_data = OAuthStateCreate(
|
||||
state="valid_state_no_user_id",
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
|
||||
@@ -153,6 +153,7 @@
|
||||
"authFailed": "Authentication Failed",
|
||||
"providerError": "The authentication provider returned an error",
|
||||
"missingParams": "Missing authentication parameters",
|
||||
"stateMismatch": "Invalid OAuth state. Please try again.",
|
||||
"unexpectedError": "An unexpected error occurred during authentication",
|
||||
"backToLogin": "Back to Login"
|
||||
}
|
||||
|
||||
@@ -153,6 +153,7 @@
|
||||
"authFailed": "Autenticazione Fallita",
|
||||
"providerError": "Il provider di autenticazione ha restituito un errore",
|
||||
"missingParams": "Parametri di autenticazione mancanti",
|
||||
"stateMismatch": "Stato OAuth non valido. Riprova.",
|
||||
"unexpectedError": "Si è verificato un errore durante l'autenticazione",
|
||||
"backToLogin": "Torna al Login"
|
||||
}
|
||||
|
||||
@@ -21,6 +21,24 @@ import { Loader2 } from 'lucide-react';
|
||||
import { useOAuthCallback } from '@/lib/api/hooks/useOAuth';
|
||||
import config from '@/config/app.config';
|
||||
|
||||
/**
|
||||
* SECURITY: Constant-time string comparison to prevent timing attacks.
|
||||
* JavaScript's === operator may short-circuit, potentially leaking information.
|
||||
* While timing attacks on frontend state are less critical (state is in URL),
|
||||
* this provides defense-in-depth.
|
||||
*/
|
||||
function constantTimeCompare(a: string, b: string): boolean {
|
||||
if (a.length !== b.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let result = 0;
|
||||
for (let i = 0; i < a.length; i++) {
|
||||
result |= a.charCodeAt(i) ^ b.charCodeAt(i);
|
||||
}
|
||||
return result === 0;
|
||||
}
|
||||
|
||||
export default function OAuthCallbackPage() {
|
||||
const params = useParams();
|
||||
const searchParams = useSearchParams();
|
||||
@@ -53,6 +71,19 @@ export default function OAuthCallbackPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
// SECURITY: Validate state parameter against stored value (CSRF protection)
|
||||
// This prevents cross-site request forgery attacks
|
||||
// Use constant-time comparison for defense-in-depth
|
||||
const storedState = sessionStorage.getItem('oauth_state');
|
||||
if (!storedState || !constantTimeCompare(storedState, state)) {
|
||||
// Clean up stored state on mismatch
|
||||
sessionStorage.removeItem('oauth_state');
|
||||
sessionStorage.removeItem('oauth_mode');
|
||||
sessionStorage.removeItem('oauth_provider');
|
||||
setError(t('stateMismatch') || 'Invalid OAuth state. Please try again.');
|
||||
return;
|
||||
}
|
||||
|
||||
hasProcessed.current = true;
|
||||
|
||||
// Process the OAuth callback
|
||||
|
||||
325
frontend/src/app/[locale]/(auth)/auth/consent/page.tsx
Normal file
325
frontend/src/app/[locale]/(auth)/auth/consent/page.tsx
Normal file
@@ -0,0 +1,325 @@
|
||||
/**
|
||||
* OAuth Consent Page
|
||||
* Displays authorization consent form for OAuth provider mode (MCP integration).
|
||||
*
|
||||
* Users are redirected here when an external application (MCP client) requests
|
||||
* access to their account. They can approve or deny the requested permissions.
|
||||
*/
|
||||
|
||||
'use client';
|
||||
|
||||
import { useState, useEffect } from 'react';
|
||||
import { useSearchParams } from 'next/navigation';
|
||||
import { useRouter } from '@/lib/i18n/routing';
|
||||
import { useTranslations } from 'next-intl';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardFooter,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from '@/components/ui/card';
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert';
|
||||
import { Checkbox } from '@/components/ui/checkbox';
|
||||
import { Label } from '@/components/ui/label';
|
||||
import { Loader2, Shield, AlertTriangle, ExternalLink, CheckCircle2 } from 'lucide-react';
|
||||
import { useAuth } from '@/lib/auth/AuthContext';
|
||||
import config from '@/config/app.config';
|
||||
|
||||
// Scope descriptions for display
|
||||
const SCOPE_INFO: Record<string, { name: string; description: string; icon: string }> = {
|
||||
openid: {
|
||||
name: 'OpenID Connect',
|
||||
description: 'Verify your identity',
|
||||
icon: 'user',
|
||||
},
|
||||
profile: {
|
||||
name: 'Profile',
|
||||
description: 'Access your name and basic profile information',
|
||||
icon: 'user-circle',
|
||||
},
|
||||
email: {
|
||||
name: 'Email',
|
||||
description: 'Access your email address',
|
||||
icon: 'mail',
|
||||
},
|
||||
'read:users': {
|
||||
name: 'Read Users',
|
||||
description: 'View user information',
|
||||
icon: 'users',
|
||||
},
|
||||
'write:users': {
|
||||
name: 'Write Users',
|
||||
description: 'Modify user information',
|
||||
icon: 'user-edit',
|
||||
},
|
||||
'read:organizations': {
|
||||
name: 'Read Organizations',
|
||||
description: 'View organization information',
|
||||
icon: 'building',
|
||||
},
|
||||
'write:organizations': {
|
||||
name: 'Write Organizations',
|
||||
description: 'Modify organization information',
|
||||
icon: 'building-edit',
|
||||
},
|
||||
admin: {
|
||||
name: 'Admin Access',
|
||||
description: 'Full administrative access',
|
||||
icon: 'shield',
|
||||
},
|
||||
};
|
||||
|
||||
interface ConsentParams {
|
||||
clientId: string;
|
||||
clientName: string;
|
||||
redirectUri: string;
|
||||
scope: string;
|
||||
state: string;
|
||||
codeChallenge: string;
|
||||
codeChallengeMethod: string;
|
||||
nonce: string;
|
||||
}
|
||||
|
||||
export default function OAuthConsentPage() {
|
||||
const searchParams = useSearchParams();
|
||||
const router = useRouter();
|
||||
// Note: t is available for future i18n use
|
||||
const _t = useTranslations('auth.oauth');
|
||||
void _t; // Suppress unused warning - ready for i18n
|
||||
const { isAuthenticated, isLoading: authLoading } = useAuth();
|
||||
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [selectedScopes, setSelectedScopes] = useState<Set<string>>(new Set());
|
||||
const [params, setParams] = useState<ConsentParams | null>(null);
|
||||
|
||||
// Parse URL parameters
|
||||
useEffect(() => {
|
||||
const clientId = searchParams.get('client_id') || '';
|
||||
const clientName = searchParams.get('client_name') || 'Application';
|
||||
const redirectUri = searchParams.get('redirect_uri') || '';
|
||||
const scope = searchParams.get('scope') || '';
|
||||
const state = searchParams.get('state') || '';
|
||||
const codeChallenge = searchParams.get('code_challenge') || '';
|
||||
const codeChallengeMethod = searchParams.get('code_challenge_method') || '';
|
||||
const nonce = searchParams.get('nonce') || '';
|
||||
|
||||
if (!clientId || !redirectUri) {
|
||||
setError('Invalid authorization request. Missing required parameters.');
|
||||
return;
|
||||
}
|
||||
|
||||
setParams({
|
||||
clientId,
|
||||
clientName,
|
||||
redirectUri,
|
||||
scope,
|
||||
state,
|
||||
codeChallenge,
|
||||
codeChallengeMethod,
|
||||
nonce,
|
||||
});
|
||||
|
||||
// Initialize selected scopes with all requested scopes
|
||||
if (scope) {
|
||||
setSelectedScopes(new Set(scope.split(' ')));
|
||||
}
|
||||
}, [searchParams]);
|
||||
|
||||
// Redirect to login if not authenticated
|
||||
useEffect(() => {
|
||||
if (!authLoading && !isAuthenticated) {
|
||||
const returnUrl = `/auth/consent?${searchParams.toString()}`;
|
||||
router.push(`${config.routes.login}?return_to=${encodeURIComponent(returnUrl)}`);
|
||||
}
|
||||
}, [authLoading, isAuthenticated, router, searchParams]);
|
||||
|
||||
const handleScopeToggle = (scope: string) => {
|
||||
setSelectedScopes((prev) => {
|
||||
const next = new Set(prev);
|
||||
if (next.has(scope)) {
|
||||
next.delete(scope);
|
||||
} else {
|
||||
next.add(scope);
|
||||
}
|
||||
return next;
|
||||
});
|
||||
};
|
||||
|
||||
const handleSubmit = async (approved: boolean) => {
|
||||
if (!params) return;
|
||||
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
// Create form data for consent submission
|
||||
const formData = new FormData();
|
||||
formData.append('approved', approved.toString());
|
||||
formData.append('client_id', params.clientId);
|
||||
formData.append('redirect_uri', params.redirectUri);
|
||||
formData.append('scope', Array.from(selectedScopes).join(' '));
|
||||
formData.append('state', params.state);
|
||||
if (params.codeChallenge) {
|
||||
formData.append('code_challenge', params.codeChallenge);
|
||||
}
|
||||
if (params.codeChallengeMethod) {
|
||||
formData.append('code_challenge_method', params.codeChallengeMethod);
|
||||
}
|
||||
if (params.nonce) {
|
||||
formData.append('nonce', params.nonce);
|
||||
}
|
||||
|
||||
// Submit consent to backend
|
||||
const apiUrl = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000';
|
||||
const response = await fetch(`${apiUrl}/api/v1/oauth/provider/authorize/consent`, {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
credentials: 'include',
|
||||
});
|
||||
|
||||
// The endpoint returns a redirect, so follow it
|
||||
if (response.redirected) {
|
||||
window.location.href = response.url;
|
||||
} else if (!response.ok) {
|
||||
const data = await response.json();
|
||||
throw new Error(data.detail || 'Failed to process consent');
|
||||
}
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'An unexpected error occurred');
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Show loading state while checking auth
|
||||
if (authLoading) {
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center p-4">
|
||||
<div className="text-center space-y-4">
|
||||
<Loader2 className="h-8 w-8 animate-spin mx-auto text-primary" />
|
||||
<p className="text-muted-foreground">Loading...</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show error state
|
||||
if (error && !params) {
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center p-4">
|
||||
<div className="w-full max-w-md space-y-4">
|
||||
<Alert variant="destructive">
|
||||
<AlertTriangle className="h-4 w-4" />
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
<div className="flex gap-2 justify-center">
|
||||
<Button variant="outline" onClick={() => router.push(config.routes.login)}>
|
||||
Back to Login
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!params) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const requestedScopes = params.scope ? params.scope.split(' ') : [];
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center p-4">
|
||||
<Card className="w-full max-w-md">
|
||||
<CardHeader className="text-center">
|
||||
<div className="flex justify-center mb-4">
|
||||
<Shield className="h-12 w-12 text-primary" />
|
||||
</div>
|
||||
<CardTitle className="text-xl">Authorization Request</CardTitle>
|
||||
<CardDescription className="mt-2">
|
||||
<span className="font-semibold text-foreground">{params.clientName}</span> wants to
|
||||
access your account
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
|
||||
<CardContent className="space-y-4">
|
||||
{error && (
|
||||
<Alert variant="destructive">
|
||||
<AlertTriangle className="h-4 w-4" />
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
<div className="space-y-3">
|
||||
<p className="text-sm font-medium">This application will be able to:</p>
|
||||
<div className="space-y-2 border rounded-lg p-3">
|
||||
{requestedScopes.map((scope) => {
|
||||
const info = SCOPE_INFO[scope] || {
|
||||
name: scope,
|
||||
description: `Access to ${scope}`,
|
||||
};
|
||||
const isSelected = selectedScopes.has(scope);
|
||||
|
||||
return (
|
||||
<div
|
||||
key={scope}
|
||||
className="flex items-start space-x-3 py-2 border-b last:border-0"
|
||||
>
|
||||
<Checkbox
|
||||
id={`scope-${scope}`}
|
||||
checked={isSelected}
|
||||
onCheckedChange={() => handleScopeToggle(scope)}
|
||||
disabled={isSubmitting}
|
||||
/>
|
||||
<div className="flex-1 space-y-0.5">
|
||||
<Label
|
||||
htmlFor={`scope-${scope}`}
|
||||
className="text-sm font-medium cursor-pointer"
|
||||
>
|
||||
{info.name}
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground">{info.description}</p>
|
||||
</div>
|
||||
{isSelected && <CheckCircle2 className="h-4 w-4 text-green-500 mt-0.5" />}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Alert>
|
||||
<ExternalLink className="h-4 w-4" />
|
||||
<AlertDescription className="text-xs">
|
||||
After authorization, you will be redirected to:
|
||||
<br />
|
||||
<code className="text-xs break-all bg-muted px-1 py-0.5 rounded">
|
||||
{params.redirectUri}
|
||||
</code>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
</CardContent>
|
||||
|
||||
<CardFooter className="flex gap-3">
|
||||
<Button
|
||||
variant="outline"
|
||||
className="flex-1"
|
||||
onClick={() => handleSubmit(false)}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isSubmitting ? <Loader2 className="h-4 w-4 animate-spin" /> : 'Deny'}
|
||||
</Button>
|
||||
<Button
|
||||
className="flex-1"
|
||||
onClick={() => handleSubmit(true)}
|
||||
disabled={isSubmitting || selectedScopes.size === 0}
|
||||
>
|
||||
{isSubmitting ? <Loader2 className="h-4 w-4 animate-spin" /> : 'Authorize'}
|
||||
</Button>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -56,6 +56,44 @@ export function useOAuthProviders() {
|
||||
// OAuth Flow Mutations
|
||||
// ============================================================================
|
||||
|
||||
// Allowed OAuth provider domains for security validation
|
||||
const ALLOWED_OAUTH_DOMAINS = [
|
||||
'accounts.google.com',
|
||||
'github.com',
|
||||
'www.facebook.com', // For future Facebook support
|
||||
'login.microsoftonline.com', // For future Microsoft support
|
||||
];
|
||||
|
||||
/**
|
||||
* Validate OAuth authorization URL
|
||||
* SECURITY: Prevents open redirect attacks by only allowing known OAuth provider domains
|
||||
*/
|
||||
function isValidOAuthUrl(url: string): boolean {
|
||||
try {
|
||||
const parsed = new URL(url);
|
||||
// Only allow HTTPS for OAuth (security requirement)
|
||||
if (parsed.protocol !== 'https:') {
|
||||
return false;
|
||||
}
|
||||
// Check if domain is in allowlist
|
||||
return ALLOWED_OAUTH_DOMAINS.includes(parsed.hostname);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract state parameter from OAuth authorization URL
|
||||
*/
|
||||
function extractStateFromUrl(url: string): string | null {
|
||||
try {
|
||||
const parsed = new URL(url);
|
||||
return parsed.searchParams.get('state');
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start OAuth login/registration flow
|
||||
* Redirects user to the OAuth provider
|
||||
@@ -77,12 +115,27 @@ export function useOAuthStart() {
|
||||
});
|
||||
|
||||
if (response.data) {
|
||||
// Store mode in sessionStorage for callback handling
|
||||
sessionStorage.setItem('oauth_mode', mode);
|
||||
sessionStorage.setItem('oauth_provider', provider);
|
||||
|
||||
// Response is { [key: string]: unknown }, so cast authorization_url
|
||||
const authUrl = (response.data as { authorization_url: string }).authorization_url;
|
||||
|
||||
// SECURITY: Validate the authorization URL before redirecting
|
||||
// This prevents open redirect attacks if the backend is compromised
|
||||
if (!isValidOAuthUrl(authUrl)) {
|
||||
throw new Error('Invalid OAuth authorization URL');
|
||||
}
|
||||
|
||||
// SECURITY: Extract and store the state parameter for CSRF validation
|
||||
// The callback page will verify this matches the state in the response
|
||||
const state = extractStateFromUrl(authUrl);
|
||||
if (!state) {
|
||||
throw new Error('Missing state parameter in authorization URL');
|
||||
}
|
||||
|
||||
// Store mode, provider, and state in sessionStorage for callback handling
|
||||
sessionStorage.setItem('oauth_mode', mode);
|
||||
sessionStorage.setItem('oauth_provider', provider);
|
||||
sessionStorage.setItem('oauth_state', state);
|
||||
|
||||
// Redirect to OAuth provider
|
||||
window.location.href = authUrl;
|
||||
}
|
||||
@@ -151,14 +204,16 @@ export function useOAuthCallback() {
|
||||
queryClient.invalidateQueries({ queryKey: ['user'] });
|
||||
}
|
||||
|
||||
// Clean up session storage
|
||||
// Clean up session storage (including state for security)
|
||||
sessionStorage.removeItem('oauth_mode');
|
||||
sessionStorage.removeItem('oauth_provider');
|
||||
sessionStorage.removeItem('oauth_state');
|
||||
},
|
||||
onError: () => {
|
||||
// Clean up session storage on error too
|
||||
sessionStorage.removeItem('oauth_mode');
|
||||
sessionStorage.removeItem('oauth_provider');
|
||||
sessionStorage.removeItem('oauth_state');
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -199,12 +254,25 @@ export function useOAuthLink() {
|
||||
});
|
||||
|
||||
if (response.data) {
|
||||
// Store mode in sessionStorage for callback handling
|
||||
sessionStorage.setItem('oauth_mode', 'link');
|
||||
sessionStorage.setItem('oauth_provider', provider);
|
||||
|
||||
// Response is { [key: string]: unknown }, so cast authorization_url
|
||||
const authUrl = (response.data as { authorization_url: string }).authorization_url;
|
||||
|
||||
// SECURITY: Validate the authorization URL before redirecting
|
||||
if (!isValidOAuthUrl(authUrl)) {
|
||||
throw new Error('Invalid OAuth authorization URL');
|
||||
}
|
||||
|
||||
// SECURITY: Extract and store the state parameter for CSRF validation
|
||||
const state = extractStateFromUrl(authUrl);
|
||||
if (!state) {
|
||||
throw new Error('Missing state parameter in authorization URL');
|
||||
}
|
||||
|
||||
// Store mode, provider, and state in sessionStorage for callback handling
|
||||
sessionStorage.setItem('oauth_mode', 'link');
|
||||
sessionStorage.setItem('oauth_provider', provider);
|
||||
sessionStorage.setItem('oauth_state', state);
|
||||
|
||||
// Redirect to OAuth provider
|
||||
window.location.href = authUrl;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user