Compare commits

..

6 Commits

Author SHA1 Message Date
Felipe Cardoso
dc875c5c95 Enhance OAuth security, PKCE, and state validation
- Enforced stricter PKCE requirements by rejecting insecure 'plain' method for public clients.
- Transitioned client secret hashing to bcrypt for improved security and migration compatibility.
- Added constant-time comparison for state parameter validation to prevent timing attacks.
- Improved error handling and logging for OAuth workflows, including malformed headers and invalid scopes.
- Upgraded Google OIDC token validation to verify both signature and nonce.
- Refactored OAuth service methods and schemas for better readability, consistency, and compliance with RFC specifications.
2025-11-26 00:14:26 +01:00
Felipe Cardoso
0ea428b718 Refactor tests for improved readability and fixture consistency
- Reformatted headers in E2E tests to improve readability and ensure consistent style.
- Updated confidential client fixture to use bcrypt for secret hashing, enhancing security and testing backward compatibility with legacy SHA-256 hashes.
- Added new test cases for PKCE verification, rejecting insecure 'plain' methods, and improved error handling.
- Refined session workflows and user agent handling in E2E tests for session management.
- Consolidated schema operation tests and fixed minor formatting inconsistencies.
2025-11-26 00:13:53 +01:00
Felipe Cardoso
400d6f6f75 Enhance OAuth security and state validation
- Implemented stricter OAuth security measures, including CSRF protection via state parameter validation and redirect_uri checks.
- Updated OAuth models to support timezone-aware datetime comparisons, replacing deprecated `utcnow`.
- Enhanced logging for malformed Basic auth headers during token, introspect, and revoke requests.
- Added allowlist validation for OAuth provider domains to prevent open redirect attacks.
- Improved nonce validation for OpenID Connect tokens, ensuring token integrity during Google provider flows.
- Updated E2E and unit tests to cover new security features and expanded OAuth state handling scenarios.
2025-11-25 23:50:43 +01:00
Felipe Cardoso
7716468d72 Add E2E tests for admin and organization workflows
- Introduced E2E tests for admin user and organization management workflows: user listing, creation, updates, bulk actions, and organization membership management.
- Added comprehensive tests for organization CRUD operations, membership visibility, roles, and permission validation.
- Expanded fixtures for superuser and member setup to streamline testing of admin-specific operations.
- Verified pagination, filtering, and action consistency across admin endpoints.
2025-11-25 23:50:02 +01:00
Felipe Cardoso
48f052200f Add OAuth provider mode and MCP integration
- Introduced full OAuth 2.0 Authorization Server functionality for MCP clients.
- Updated documentation with details on endpoints, scopes, and consent management.
- Added a new frontend OAuth consent page for user authorization flows.
- Implemented database models for authorization codes, refresh tokens, and user consents.
- Created unit tests for service methods (PKCE verification, client validation, scope handling).
- Included comprehensive integration tests for OAuth provider workflows.
2025-11-25 23:18:19 +01:00
Felipe Cardoso
fbb030da69 Add E2E workflow tests for organizations, users, sessions, and API contracts
- Introduced comprehensive E2E tests for organization workflows: creation, membership management, and updates.
- Added tests for user management workflows: profile viewing, updates, password changes, and settings.
- Implemented session management tests, including listing, revocation, multi-device handling, and cleanup.
- Included API contract validation tests using Schemathesis, covering protected endpoints and schema structure.
- Enhanced E2E testing infrastructure with full PostgreSQL support and detailed workflow coverage.
2025-11-25 23:13:28 +01:00
26 changed files with 6028 additions and 183 deletions

View File

@@ -42,7 +42,7 @@ Default superuser (change in production):
│ │ ├── schemas/ # Pydantic request/response schemas
│ │ ├── services/ # Business logic layer
│ │ └── utils/ # Utilities (security, device detection)
│ ├── tests/ # 97% coverage, 743 tests
│ ├── tests/ # 96% coverage, 987 tests
│ └── alembic/ # Database migrations
└── frontend/ # Next.js 15 frontend
@@ -69,6 +69,27 @@ Default superuser (change in production):
- `get_optional_current_user`: Accepts authenticated or anonymous
- `get_current_superuser`: Requires superuser flag
### OAuth Provider Mode (MCP Integration)
Full OAuth 2.0 Authorization Server for MCP (Model Context Protocol) clients:
- **Authorization Code Flow with PKCE**: RFC 7636 compliant
- **JWT access tokens**: Self-contained, no DB lookup required
- **Opaque refresh tokens**: Stored hashed in database, supports rotation
- **Token introspection**: RFC 7662 compliant endpoint
- **Token revocation**: RFC 7009 compliant endpoint
- **Server metadata**: RFC 8414 compliant discovery endpoint
- **Consent management**: User can review and revoke app permissions
**API endpoints:**
- `GET /.well-known/oauth-authorization-server` - Server metadata
- `GET /oauth/provider/authorize` - Authorization endpoint
- `POST /oauth/provider/authorize/consent` - Consent submission
- `POST /oauth/provider/token` - Token endpoint
- `POST /oauth/provider/revoke` - Token revocation
- `POST /oauth/provider/introspect` - Token introspection
- Client management endpoints (admin only)
**Scopes supported:** `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
### Database Pattern
- **Async SQLAlchemy 2.0** with PostgreSQL
- **Connection pooling**: 20 base connections, 50 max overflow
@@ -107,7 +128,7 @@ Permission dependencies in `api/dependencies/permissions.py`:
### Testing Infrastructure
**Backend Unit/Integration (pytest + SQLite):**
- 97% coverage, 743+ tests
- 96% coverage, 987 tests
- Security-focused: JWT attacks, session hijacking, privilege escalation
- Async fixtures in `tests/conftest.py`
- Run: `IS_TEST=True uv run pytest` or `make test`
@@ -238,12 +259,13 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
### Completed Features ✅
- Authentication system (JWT with refresh tokens, OAuth/social login)
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
- Session management (device tracking, revocation)
- User management (CRUD, password change)
- Organization system (multi-tenant with RBAC)
- Admin panel (user/org management, bulk operations)
- **Internationalization (i18n)** with English and Italian
- Comprehensive test coverage (97% backend, 97% frontend unit, 56 E2E tests)
- Comprehensive test coverage (96% backend, 97% frontend unit, 56 E2E tests)
- Design system documentation
- **Marketing landing page** with animations
- **`/dev` documentation portal** with live examples

View File

@@ -0,0 +1,194 @@
"""Add OAuth provider models for MCP integration.
Revision ID: f8c3d2e1a4b5
Revises: d5a7b2c9e1f3
Create Date: 2025-01-15 10:00:00.000000
This migration adds tables for OAuth provider mode:
- oauth_authorization_codes: Temporary authorization codes
- oauth_provider_refresh_tokens: Long-lived refresh tokens
- oauth_consents: User consent records
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "f8c3d2e1a4b5"
down_revision = "d5a7b2c9e1f3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create oauth_authorization_codes table
op.create_table(
"oauth_authorization_codes",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("code", sa.String(128), nullable=False),
sa.Column("client_id", sa.String(64), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("redirect_uri", sa.String(2048), nullable=False),
sa.Column("scope", sa.String(1000), nullable=False, server_default=""),
sa.Column("code_challenge", sa.String(128), nullable=True),
sa.Column("code_challenge_method", sa.String(10), nullable=True),
sa.Column("state", sa.String(256), nullable=True),
sa.Column("nonce", sa.String(256), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("used", sa.Boolean(), nullable=False, server_default="false"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_clients.client_id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_authorization_codes_code",
"oauth_authorization_codes",
["code"],
unique=True,
)
op.create_index(
"ix_oauth_authorization_codes_expires_at",
"oauth_authorization_codes",
["expires_at"],
)
op.create_index(
"ix_oauth_authorization_codes_client_user",
"oauth_authorization_codes",
["client_id", "user_id"],
)
# Create oauth_provider_refresh_tokens table
op.create_table(
"oauth_provider_refresh_tokens",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("token_hash", sa.String(64), nullable=False),
sa.Column("jti", sa.String(64), nullable=False),
sa.Column("client_id", sa.String(64), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("scope", sa.String(1000), nullable=False, server_default=""),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False, server_default="false"),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("device_info", sa.String(500), nullable=True),
sa.Column("ip_address", sa.String(45), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_clients.client_id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_provider_refresh_tokens_token_hash",
"oauth_provider_refresh_tokens",
["token_hash"],
unique=True,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_jti",
"oauth_provider_refresh_tokens",
["jti"],
unique=True,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_expires_at",
"oauth_provider_refresh_tokens",
["expires_at"],
)
op.create_index(
"ix_oauth_provider_refresh_tokens_client_user",
"oauth_provider_refresh_tokens",
["client_id", "user_id"],
)
op.create_index(
"ix_oauth_provider_refresh_tokens_user_revoked",
"oauth_provider_refresh_tokens",
["user_id", "revoked"],
)
op.create_index(
"ix_oauth_provider_refresh_tokens_revoked",
"oauth_provider_refresh_tokens",
["revoked"],
)
# Create oauth_consents table
op.create_table(
"oauth_consents",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("client_id", sa.String(64), nullable=False),
sa.Column("granted_scopes", sa.String(1000), nullable=False, server_default=""),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["client_id"],
["oauth_clients.client_id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_consents_user_client",
"oauth_consents",
["user_id", "client_id"],
unique=True,
)
def downgrade() -> None:
op.drop_table("oauth_consents")
op.drop_table("oauth_provider_refresh_tokens")
op.drop_table("oauth_authorization_codes")

View File

@@ -1,37 +1,63 @@
# app/api/routes/oauth_provider.py
"""
OAuth Provider routes (Authorization Server mode).
OAuth Provider routes (Authorization Server mode) for MCP integration.
This is a skeleton implementation for MCP (Model Context Protocol) client authentication.
Provides basic OAuth 2.0 endpoints that can be expanded for full functionality.
Endpoints:
Implements OAuth 2.0 Authorization Server endpoints:
- GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414)
- GET /oauth/provider/authorize - Authorization endpoint (skeleton)
- POST /oauth/provider/token - Token endpoint (skeleton)
- POST /oauth/provider/revoke - Token revocation endpoint (skeleton)
- GET /oauth/provider/authorize - Authorization endpoint
- POST /oauth/provider/token - Token endpoint
- POST /oauth/provider/revoke - Token revocation (RFC 7009)
- POST /oauth/provider/introspect - Token introspection (RFC 7662)
- Client management endpoints
NOTE: This is intentionally minimal. Full implementation should include:
- Complete authorization code flow
- Refresh token handling
- Scope validation
- Client authentication
- PKCE support
Security features:
- PKCE required for public clients (S256)
- CSRF protection via state parameter
- Secure token handling
- Rate limiting on sensitive endpoints
"""
import logging
from typing import Any
from urllib.parse import urlencode
from fastapi import APIRouter, Depends, Form, HTTPException, Query, status
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, status
from fastapi.responses import RedirectResponse
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_active_user, get_current_superuser
from app.core.config import settings
from app.core.database import get_db
from app.crud import oauth_client
from app.schemas.oauth import OAuthServerMetadata
from app.crud import oauth_client as oauth_client_crud
from app.models.user import User
from app.schemas.oauth import (
OAuthClientCreate,
OAuthClientResponse,
OAuthServerMetadata,
OAuthTokenIntrospectionResponse,
OAuthTokenResponse,
)
from app.services import oauth_provider_service as provider_service
router = APIRouter()
logger = logging.getLogger(__name__)
limiter = Limiter(key_func=get_remote_address)
def require_provider_enabled():
"""Dependency to check if OAuth provider mode is enabled."""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled. Set OAUTH_PROVIDER_ENABLED=true",
)
# ============================================================================
# Server Metadata (RFC 8414)
# ============================================================================
@router.get(
@@ -42,24 +68,15 @@ logger = logging.getLogger(__name__)
OAuth 2.0 Authorization Server Metadata (RFC 8414).
Returns server metadata including supported endpoints, scopes,
and capabilities for MCP clients.
and capabilities. MCP clients use this to discover the server.
""",
operation_id="get_oauth_server_metadata",
tags=["OAuth Provider"],
)
async def get_server_metadata() -> Any:
"""
Get OAuth 2.0 server metadata.
This endpoint is used by MCP clients to discover the authorization
server's capabilities.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
)
async def get_server_metadata(
_: None = Depends(require_provider_enabled),
) -> OAuthServerMetadata:
"""Get OAuth 2.0 server metadata."""
base_url = settings.OAUTH_ISSUER.rstrip("/")
return OAuthServerMetadata(
@@ -67,7 +84,8 @@ async def get_server_metadata() -> Any:
authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize",
token_endpoint=f"{base_url}/api/v1/oauth/provider/token",
revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke",
registration_endpoint=None, # Dynamic registration not implemented
introspection_endpoint=f"{base_url}/api/v1/oauth/provider/introspect",
registration_endpoint=None, # Dynamic registration not supported
scopes_supported=[
"openid",
"profile",
@@ -76,148 +94,446 @@ async def get_server_metadata() -> Any:
"write:users",
"read:organizations",
"write:organizations",
"admin",
],
response_types_supported=["code"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["S256"],
token_endpoint_auth_methods_supported=[
"client_secret_basic",
"client_secret_post",
"none", # For public clients with PKCE
],
)
# ============================================================================
# Authorization Endpoint
# ============================================================================
@router.get(
"/provider/authorize",
summary="Authorization Endpoint (Skeleton)",
summary="Authorization Endpoint",
description="""
OAuth 2.0 Authorization Endpoint.
**NOTE**: This is a skeleton implementation. In a full implementation,
this would:
1. Validate client_id and redirect_uri
2. Display consent screen to user
3. Generate authorization code
4. Redirect back to client with code
Initiates the authorization code flow:
1. Validates client and parameters
2. Checks if user is authenticated (redirects to login if not)
3. Checks existing consent
4. Redirects to consent page if needed
5. Issues authorization code and redirects back to client
Currently returns a 501 Not Implemented response.
Required parameters:
- response_type: Must be "code"
- client_id: Registered client ID
- redirect_uri: Must match registered URI
Recommended parameters:
- state: CSRF protection
- code_challenge + code_challenge_method: PKCE (required for public clients)
- scope: Requested permissions
""",
operation_id="oauth_provider_authorize",
tags=["OAuth Provider"],
)
@limiter.limit("30/minute")
async def authorize(
request: Request,
response_type: str = Query(..., description="Must be 'code'"),
client_id: str = Query(..., description="OAuth client ID"),
redirect_uri: str = Query(..., description="Redirect URI"),
scope: str = Query(default="", description="Requested scopes"),
scope: str = Query(default="", description="Requested scopes (space-separated)"),
state: str = Query(default="", description="CSRF state parameter"),
code_challenge: str | None = Query(default=None, description="PKCE code challenge"),
code_challenge_method: str | None = Query(
default=None, description="PKCE method (S256)"
),
nonce: str | None = Query(default=None, description="OpenID Connect nonce"),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User | None = Depends(get_current_active_user),
) -> Any:
"""
Authorization endpoint (skeleton).
Authorization endpoint - initiates OAuth flow.
In a full implementation, this would:
1. Validate the client and redirect URI
2. Authenticate the user (if not already)
3. Show consent screen
4. Generate authorization code
5. Redirect to redirect_uri with code
If user is not authenticated, redirects to login with return URL.
If user has not consented, redirects to consent page.
If all checks pass, generates code and redirects to client.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
)
# Validate client exists
client = await oauth_client.get_by_client_id(db, client_id=client_id)
if not client:
# Validate response_type
if response_type != "code":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="invalid_client: Unknown client_id",
detail="invalid_request: response_type must be 'code'",
)
# Validate redirect_uri
if redirect_uri not in (client.redirect_uris or []):
# Validate PKCE method if provided - ONLY S256 is allowed (RFC 7636 Section 4.3)
# "plain" method provides no security benefit and MUST NOT be used
if code_challenge_method and code_challenge_method != "S256":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="invalid_request: Invalid redirect_uri",
detail="invalid_request: code_challenge_method must be 'S256' (plain is not supported)",
)
# Skeleton: Return not implemented
# Full implementation would redirect to consent screen
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Authorization endpoint not fully implemented. "
"This is a skeleton for MCP integration.",
# Validate client
try:
client = await provider_service.get_client(db, client_id)
if not client:
raise provider_service.InvalidClientError("Unknown client_id")
provider_service.validate_redirect_uri(client, redirect_uri)
except provider_service.OAuthProviderError as e:
# For client/redirect errors, we can't safely redirect - show error
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{e.error}: {e.error_description}",
)
# Validate and filter scopes
try:
requested_scopes = provider_service.parse_scope(scope)
valid_scopes = provider_service.validate_scopes(client, requested_scopes)
except provider_service.InvalidScopeError as e:
# Redirect with error
scope_error_params: dict[str, str] = {"error": e.error}
if e.error_description:
scope_error_params["error_description"] = e.error_description
if state:
scope_error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(scope_error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Public clients MUST use PKCE
if client.client_type == "public":
if not code_challenge or code_challenge_method != "S256":
pkce_error_params: dict[str, str] = {
"error": "invalid_request",
"error_description": "PKCE with S256 is required for public clients",
}
if state:
pkce_error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(pkce_error_params)}",
status_code=status.HTTP_302_FOUND,
)
# If user is not authenticated, redirect to login
if not current_user:
# Store authorization request in session and redirect to login
# The frontend will handle the return URL
login_url = f"{settings.FRONTEND_URL}/login"
return_params = urlencode(
{
"oauth_authorize": "true",
"client_id": client_id,
"redirect_uri": redirect_uri,
"scope": " ".join(valid_scopes),
"state": state,
"code_challenge": code_challenge or "",
"code_challenge_method": code_challenge_method or "",
"nonce": nonce or "",
}
)
return RedirectResponse(
url=f"{login_url}?return_to=/auth/consent?{return_params}",
status_code=status.HTTP_302_FOUND,
)
# Check if user has already consented
has_consent = await provider_service.check_consent(
db, current_user.id, client_id, valid_scopes
)
if not has_consent:
# Redirect to consent page
consent_params = urlencode(
{
"client_id": client_id,
"client_name": client.client_name,
"redirect_uri": redirect_uri,
"scope": " ".join(valid_scopes),
"state": state,
"code_challenge": code_challenge or "",
"code_challenge_method": code_challenge_method or "",
"nonce": nonce or "",
}
)
return RedirectResponse(
url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}",
status_code=status.HTTP_302_FOUND,
)
# User is authenticated and has consented - issue authorization code
try:
code = await provider_service.create_authorization_code(
db=db,
client=client,
user=current_user,
redirect_uri=redirect_uri,
scope=" ".join(valid_scopes),
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
)
except provider_service.OAuthProviderError as e:
error_params: dict[str, str] = {"error": e.error}
if e.error_description:
error_params["error_description"] = e.error_description
if state:
error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Success - redirect with code
success_params = {"code": code}
if state:
success_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(success_params)}",
status_code=status.HTTP_302_FOUND,
)
@router.post(
"/provider/authorize/consent",
summary="Submit Authorization Consent",
description="""
Submit user consent for OAuth authorization.
Called by the consent page after user approves or denies.
""",
operation_id="oauth_provider_consent",
tags=["OAuth Provider"],
)
@limiter.limit("30/minute")
async def submit_consent(
request: Request,
approved: bool = Form(..., description="Whether user approved"),
client_id: str = Form(..., description="OAuth client ID"),
redirect_uri: str = Form(..., description="Redirect URI"),
scope: str = Form(default="", description="Granted scopes"),
state: str = Form(default="", description="CSRF state parameter"),
code_challenge: str | None = Form(default=None),
code_challenge_method: str | None = Form(default=None),
nonce: str | None = Form(default=None),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_active_user),
) -> Any:
"""Process consent form submission."""
# Validate client
try:
client = await provider_service.get_client(db, client_id)
if not client:
raise provider_service.InvalidClientError("Unknown client_id")
provider_service.validate_redirect_uri(client, redirect_uri)
except provider_service.OAuthProviderError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{e.error}: {e.error_description}",
)
# If user denied, redirect with error
if not approved:
denied_params: dict[str, str] = {
"error": "access_denied",
"error_description": "User denied authorization",
}
if state:
denied_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(denied_params)}",
status_code=status.HTTP_302_FOUND,
)
# Parse and validate scopes
granted_scopes = provider_service.parse_scope(scope)
valid_scopes = provider_service.validate_scopes(client, granted_scopes)
# Record consent
await provider_service.grant_consent(db, current_user.id, client_id, valid_scopes)
# Generate authorization code
try:
code = await provider_service.create_authorization_code(
db=db,
client=client,
user=current_user,
redirect_uri=redirect_uri,
scope=" ".join(valid_scopes),
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
)
except provider_service.OAuthProviderError as e:
error_params: dict[str, str] = {"error": e.error}
if e.error_description:
error_params["error_description"] = e.error_description
if state:
error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Success
success_params = {"code": code}
if state:
success_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(success_params)}",
status_code=status.HTTP_302_FOUND,
)
# ============================================================================
# Token Endpoint
# ============================================================================
@router.post(
"/provider/token",
summary="Token Endpoint (Skeleton)",
response_model=OAuthTokenResponse,
summary="Token Endpoint",
description="""
OAuth 2.0 Token Endpoint.
**NOTE**: This is a skeleton implementation. In a full implementation,
this would exchange authorization codes for access tokens.
Supports:
- authorization_code: Exchange code for tokens
- refresh_token: Refresh access token
Currently returns a 501 Not Implemented response.
Client authentication:
- Confidential clients: client_secret (Basic auth or POST body)
- Public clients: No secret, but PKCE code_verifier required
""",
operation_id="oauth_provider_token",
tags=["OAuth Provider"],
)
@limiter.limit("60/minute")
async def token(
grant_type: str = Form(..., description="Grant type (authorization_code)"),
request: Request,
grant_type: str = Form(..., description="Grant type"),
code: str | None = Form(default=None, description="Authorization code"),
redirect_uri: str | None = Form(default=None, description="Redirect URI"),
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
code_verifier: str | None = Form(default=None, description="PKCE code verifier"),
refresh_token: str | None = Form(default=None, description="Refresh token"),
scope: str | None = Form(default=None, description="Scope (for refresh)"),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Token endpoint (skeleton).
_: None = Depends(require_provider_enabled),
) -> OAuthTokenResponse:
"""Token endpoint - exchange code for tokens or refresh."""
# Extract client credentials from Basic auth if not in body
if not client_id:
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
import base64
Supported grant types (when fully implemented):
- authorization_code: Exchange code for tokens
- refresh_token: Refresh access token
"""
if not settings.OAUTH_PROVIDER_ENABLED:
try:
decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1)
except Exception as e:
# Log malformed Basic auth for security monitoring
logger.warning(
f"Malformed Basic auth header in token request: {type(e).__name__}"
)
# Fall back to form body
if not client_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_client: client_id required",
headers={"WWW-Authenticate": "Basic"},
)
if grant_type not in ["authorization_code", "refresh_token"]:
# Get device info
device_info = request.headers.get("User-Agent", "")[:500]
ip_address = get_remote_address(request)
try:
if grant_type == "authorization_code":
if not code:
raise provider_service.InvalidRequestError("code required")
if not redirect_uri:
raise provider_service.InvalidRequestError("redirect_uri required")
result = await provider_service.exchange_authorization_code(
db=db,
code=code,
client_id=client_id,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
client_secret=client_secret,
device_info=device_info,
ip_address=ip_address,
)
elif grant_type == "refresh_token":
if not refresh_token:
raise provider_service.InvalidRequestError("refresh_token required")
result = await provider_service.refresh_tokens(
db=db,
refresh_token=refresh_token,
client_id=client_id,
client_secret=client_secret,
scope=scope,
device_info=device_info,
ip_address=ip_address,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="unsupported_grant_type: Must be authorization_code or refresh_token",
)
return OAuthTokenResponse(**result)
except provider_service.InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"{e.error}: {e.error_description}",
headers={"WWW-Authenticate": "Basic"},
)
except provider_service.OAuthProviderError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="unsupported_grant_type",
detail=f"{e.error}: {e.error_description}",
)
# Skeleton: Return not implemented
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Token endpoint not fully implemented. "
"This is a skeleton for MCP integration.",
)
# ============================================================================
# Token Revocation (RFC 7009)
# ============================================================================
@router.post(
"/provider/revoke",
summary="Token Revocation Endpoint (Skeleton)",
status_code=status.HTTP_200_OK,
summary="Token Revocation Endpoint",
description="""
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
**NOTE**: This is a skeleton implementation.
Currently returns a 501 Not Implemented response.
Revokes an access token or refresh token.
Always returns 200 OK (even if token is invalid) per spec.
""",
operation_id="oauth_provider_revoke",
tags=["OAuth Provider"],
)
@limiter.limit("30/minute")
async def revoke(
request: Request,
token: str = Form(..., description="Token to revoke"),
token_type_hint: str | None = Form(
default=None, description="Token type hint (access_token, refresh_token)"
@@ -225,88 +541,298 @@ async def revoke(
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Token revocation endpoint (skeleton).
_: None = Depends(require_provider_enabled),
) -> dict[str, str]:
"""Revoke a token."""
# Extract client credentials from Basic auth if not in body
if not client_id:
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
import base64
In a full implementation, this would invalidate the specified token.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
try:
decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1)
except Exception as e:
# Log malformed Basic auth for security monitoring
logger.warning(
f"Malformed Basic auth header in revoke request: {type(e).__name__}"
)
# Fall back to form body
try:
await provider_service.revoke_token(
db=db,
token=token,
token_type_hint=token_type_hint,
client_id=client_id,
client_secret=client_secret,
)
except provider_service.InvalidClientError:
# Per RFC 7009, we should return 200 OK even for errors
# But client authentication errors can return 401
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_client",
headers={"WWW-Authenticate": "Basic"},
)
except Exception as e:
# Log but don't expose errors per RFC 7009
logger.warning(f"Token revocation error: {e}")
# Skeleton: Return not implemented
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Revocation endpoint not fully implemented. "
"This is a skeleton for MCP integration.",
)
# Always return 200 OK per RFC 7009
return {"status": "ok"}
# ============================================================================
# Client Management (Admin only)
# Token Introspection (RFC 7662)
# ============================================================================
@router.post(
"/provider/introspect",
response_model=OAuthTokenIntrospectionResponse,
summary="Token Introspection Endpoint",
description="""
OAuth 2.0 Token Introspection Endpoint (RFC 7662).
Allows resource servers to query the authorization server
to determine the active state and metadata of a token.
""",
operation_id="oauth_provider_introspect",
tags=["OAuth Provider"],
)
@limiter.limit("120/minute")
async def introspect(
request: Request,
token: str = Form(..., description="Token to introspect"),
token_type_hint: str | None = Form(
default=None, description="Token type hint (access_token, refresh_token)"
),
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
) -> OAuthTokenIntrospectionResponse:
"""Introspect a token."""
# Extract client credentials from Basic auth if not in body
if not client_id:
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
import base64
try:
decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1)
except Exception as e:
# Log malformed Basic auth for security monitoring
logger.warning(
f"Malformed Basic auth header in introspect request: {type(e).__name__}"
)
# Fall back to form body
try:
result = await provider_service.introspect_token(
db=db,
token=token,
token_type_hint=token_type_hint,
client_id=client_id,
client_secret=client_secret,
)
return OAuthTokenIntrospectionResponse(**result)
except provider_service.InvalidClientError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_client",
headers={"WWW-Authenticate": "Basic"},
)
except Exception as e:
logger.warning(f"Token introspection error: {e}")
return OAuthTokenIntrospectionResponse(active=False)
# ============================================================================
# Client Management (Admin)
# ============================================================================
@router.post(
"/provider/clients",
summary="Register OAuth Client (Admin)",
response_model=dict,
summary="Register OAuth Client",
description="""
Register a new OAuth client (admin only).
This endpoint allows creating MCP clients that can authenticate
against this API.
Creates an MCP client that can authenticate against this API.
Returns client_id and client_secret (for confidential clients).
**NOTE**: This is a minimal implementation.
**Important:** Store the client_secret securely - it won't be shown again!
""",
operation_id="register_oauth_client",
tags=["OAuth Provider"],
tags=["OAuth Provider Admin"],
)
async def register_client(
client_name: str = Form(..., description="Client application name"),
redirect_uris: str = Form(..., description="Comma-separated list of redirect URIs"),
redirect_uris: str = Form(..., description="Comma-separated redirect URIs"),
client_type: str = Form(default="public", description="public or confidential"),
scopes: str = Form(
default="openid profile email",
description="Allowed scopes (space-separated)",
),
mcp_server_url: str | None = Form(default=None, description="MCP server URL"),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Register a new OAuth client (skeleton).
In a full implementation, this would require admin authentication.
"""
if not settings.OAUTH_PROVIDER_ENABLED:
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_superuser),
) -> dict:
"""Register a new OAuth client."""
# Parse redirect URIs
uris = [uri.strip() for uri in redirect_uris.split(",") if uri.strip()]
if not uris:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled",
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one redirect_uri is required",
)
# NOTE: In production, this should require admin authentication
# For now, this is a skeleton that shows the structure
from app.schemas.oauth import OAuthClientCreate
# Parse scopes
allowed_scopes = [s.strip() for s in scopes.split() if s.strip()]
client_data = OAuthClientCreate(
client_name=client_name,
client_description=None,
redirect_uris=[uri.strip() for uri in redirect_uris.split(",")],
allowed_scopes=["openid", "profile", "email"],
redirect_uris=uris,
allowed_scopes=allowed_scopes,
client_type=client_type,
)
client, secret = await oauth_client.create_client(db, obj_in=client_data)
client, secret = await oauth_client_crud.create_client(db, obj_in=client_data)
# Update MCP server URL if provided
if mcp_server_url:
client.mcp_server_url = mcp_server_url
await db.commit()
result = {
"client_id": client.client_id,
"client_name": client.client_name,
"client_type": client.client_type,
"redirect_uris": client.redirect_uris,
"allowed_scopes": client.allowed_scopes,
}
if secret:
result["client_secret"] = secret
result["warning"] = (
"Store the client_secret securely. It will not be shown again."
"Store the client_secret securely! It will not be shown again."
)
return result
@router.get(
"/provider/clients",
response_model=list[OAuthClientResponse],
summary="List OAuth Clients",
description="List all registered OAuth clients (admin only).",
operation_id="list_oauth_clients",
tags=["OAuth Provider Admin"],
)
async def list_clients(
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_superuser),
) -> list[OAuthClientResponse]:
"""List all OAuth clients."""
clients = await oauth_client_crud.get_all_clients(db)
return [OAuthClientResponse.model_validate(c) for c in clients]
@router.delete(
"/provider/clients/{client_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete OAuth Client",
description="Delete an OAuth client (admin only). Revokes all tokens.",
operation_id="delete_oauth_client",
tags=["OAuth Provider Admin"],
)
async def delete_client(
client_id: str,
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_superuser),
) -> None:
"""Delete an OAuth client."""
client = await provider_service.get_client(db, client_id)
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Client not found",
)
await oauth_client_crud.delete_client(db, client_id=client_id)
# ============================================================================
# User Consent Management
# ============================================================================
@router.get(
"/provider/consents",
summary="List My Consents",
description="List OAuth applications the current user has authorized.",
operation_id="list_my_oauth_consents",
tags=["OAuth Provider"],
)
async def list_my_consents(
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_active_user),
) -> list[dict]:
"""List applications the user has authorized."""
from sqlalchemy import select
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent
result = await db.execute(
select(OAuthConsent, OAuthClient)
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
.where(OAuthConsent.user_id == current_user.id)
)
rows = result.all()
return [
{
"client_id": consent.client_id,
"client_name": client.client_name,
"client_description": client.client_description,
"granted_scopes": consent.granted_scopes.split()
if consent.granted_scopes
else [],
"granted_at": consent.created_at.isoformat(),
}
for consent, client in rows
]
@router.delete(
"/provider/consents/{client_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Revoke My Consent",
description="Revoke authorization for an OAuth application. Also revokes all tokens.",
operation_id="revoke_my_oauth_consent",
tags=["OAuth Provider"],
)
async def revoke_my_consent(
client_id: str,
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_active_user),
) -> None:
"""Revoke consent for an application."""
revoked = await provider_service.revoke_consent(db, current_user.id, client_id)
if not revoked:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No consent found for this client",
)

View File

@@ -515,11 +515,11 @@ class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
# In production, use proper password hashing (bcrypt)
# For now, we store a hash placeholder
import hashlib
# SECURITY: Use bcrypt for secret storage (not SHA-256)
# bcrypt is computationally expensive, making brute-force attacks infeasible
from app.core.auth import get_password_hash
client_secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
client_secret_hash = get_password_hash(client_secret)
db_obj = OAuthClient(
client_id=client_id,
@@ -632,17 +632,82 @@ class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
if client is None or client.client_secret_hash is None:
return False
# Verify secret
import hashlib
# SECURITY: Verify secret using bcrypt (not SHA-256)
# This supports both old SHA-256 hashes (for migration) and new bcrypt hashes
from app.core.auth import verify_password
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
# Cast to str for type safety with compare_digest
stored_hash: str = str(client.client_secret_hash)
return secrets.compare_digest(stored_hash, secret_hash)
# Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256
if stored_hash.startswith("$2"):
# New bcrypt format
return verify_password(client_secret, stored_hash)
else:
# Legacy SHA-256 format - still support for migration
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e: # pragma: no cover
logger.error(f"Error verifying client secret: {e!s}")
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""
Get all OAuth clients.
Args:
db: Database session
include_inactive: Whether to include inactive clients
Returns:
List of OAuthClient objects
"""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting all OAuth clients: {e!s}")
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""
Delete an OAuth client permanently.
Note: This will cascade delete related records (tokens, consents, etc.)
due to foreign key constraints.
Args:
db: Database session
client_id: OAuth client ID
Returns:
True if deleted, False if not found
"""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(f"OAuth client deleted: {client_id}")
else:
logger.warning(f"OAuth client not found for deletion: {client_id}")
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
raise
# ============================================================================
# Singleton instances

View File

@@ -8,9 +8,13 @@ from app.core.database import Base
from .base import TimestampMixin, UUIDMixin
# OAuth models
# OAuth models (client mode - authenticate via Google/GitHub)
from .oauth_account import OAuthAccount
# OAuth provider models (server mode - act as authorization server for MCP)
from .oauth_authorization_code import OAuthAuthorizationCode
from .oauth_client import OAuthClient
from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
from .oauth_state import OAuthState
from .organization import Organization
@@ -22,7 +26,10 @@ from .user_session import UserSession
__all__ = [
"Base",
"OAuthAccount",
"OAuthAuthorizationCode",
"OAuthClient",
"OAuthConsent",
"OAuthProviderRefreshToken",
"OAuthState",
"Organization",
"OrganizationRole",

View File

@@ -0,0 +1,97 @@
"""OAuth authorization code model for OAuth provider mode."""
from datetime import UTC, datetime
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
"""
OAuth 2.0 Authorization Code for the authorization code flow.
Authorization codes are:
- Single-use (marked as used after exchange)
- Short-lived (10 minutes default)
- Bound to specific client, user, redirect_uri
- Support PKCE (code_challenge/code_challenge_method)
Security considerations:
- Code must be cryptographically random (64 chars, URL-safe)
- Must validate redirect_uri matches exactly
- Must verify PKCE code_verifier for public clients
- Must be consumed within expiration time
"""
__tablename__ = "oauth_authorization_codes"
# The authorization code (cryptographically random, URL-safe)
code = Column(String(128), unique=True, nullable=False, index=True)
# Client that requested the code
client_id = Column(
String(64),
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
nullable=False,
)
# User who authorized the request
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Redirect URI (must match exactly on token exchange)
redirect_uri = Column(String(2048), nullable=False)
# Granted scopes (space-separated)
scope = Column(String(1000), nullable=False, default="")
# PKCE support (required for public clients)
code_challenge = Column(String(128), nullable=True)
code_challenge_method = Column(String(10), nullable=True) # "S256" or "plain"
# State parameter (for CSRF protection, returned to client)
state = Column(String(256), nullable=True)
# Nonce (for OpenID Connect, included in ID token)
nonce = Column(String(256), nullable=True)
# Expiration (codes are short-lived, typically 10 minutes)
expires_at = Column(DateTime(timezone=True), nullable=False)
# Single-use flag (set to True after successful exchange)
used = Column(Boolean, default=False, nullable=False)
# Relationships
client = relationship("OAuthClient", backref="authorization_codes")
user = relationship("User", backref="oauth_authorization_codes")
# Indexes for efficient cleanup queries
__table_args__ = (
Index("ix_oauth_authorization_codes_expires_at", "expires_at"),
Index("ix_oauth_authorization_codes_client_user", "client_id", "user_id"),
)
def __repr__(self):
return f"<OAuthAuthorizationCode {self.code[:8]}... for {self.client_id}>"
@property
def is_expired(self) -> bool:
"""Check if the authorization code has expired."""
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
now = datetime.now(UTC)
expires_at = self.expires_at
# Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at
@property
def is_valid(self) -> bool:
"""Check if the authorization code is valid (not used, not expired)."""
return not self.used and not self.is_expired

View File

@@ -0,0 +1,159 @@
"""OAuth provider token models for OAuth provider mode."""
from datetime import UTC, datetime
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
"""
OAuth 2.0 Refresh Token for the OAuth provider.
Refresh tokens are:
- Opaque (stored as hash in DB, actual token given to client)
- Long-lived (configurable, default 30 days)
- Revocable (via revoked flag or deletion)
- Bound to specific client, user, and scope
Access tokens are JWTs and not stored in DB (self-contained).
This model only tracks refresh tokens for revocation support.
Security considerations:
- Store token hash, not plaintext
- Support token rotation (new refresh token on use)
- Track last used time for security auditing
- Support revocation by user, client, or admin
"""
__tablename__ = "oauth_provider_refresh_tokens"
# Hash of the refresh token (SHA-256)
# We store hash, not plaintext, for security
token_hash = Column(String(64), unique=True, nullable=False, index=True)
# Unique token ID (JTI) - used in JWT access tokens to reference this refresh token
jti = Column(String(64), unique=True, nullable=False, index=True)
# Client that owns this token
client_id = Column(
String(64),
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
nullable=False,
)
# User who authorized this token
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Granted scopes (space-separated)
scope = Column(String(1000), nullable=False, default="")
# Token expiration
expires_at = Column(DateTime(timezone=True), nullable=False)
# Revocation flag
revoked = Column(Boolean, default=False, nullable=False, index=True)
# Last used timestamp (for security auditing)
last_used_at = Column(DateTime(timezone=True), nullable=True)
# Device/session info (optional, for user visibility)
device_info = Column(String(500), nullable=True)
ip_address = Column(String(45), nullable=True)
# Relationships
client = relationship("OAuthClient", backref="refresh_tokens")
user = relationship("User", backref="oauth_provider_refresh_tokens")
# Indexes
__table_args__ = (
Index("ix_oauth_provider_refresh_tokens_expires_at", "expires_at"),
Index("ix_oauth_provider_refresh_tokens_client_user", "client_id", "user_id"),
Index(
"ix_oauth_provider_refresh_tokens_user_revoked",
"user_id",
"revoked",
),
)
def __repr__(self):
status = "revoked" if self.revoked else "active"
return f"<OAuthProviderRefreshToken {self.jti[:8]}... ({status})>"
@property
def is_expired(self) -> bool:
"""Check if the refresh token has expired."""
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
now = datetime.now(UTC)
expires_at = self.expires_at
# Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at
@property
def is_valid(self) -> bool:
"""Check if the refresh token is valid (not revoked, not expired)."""
return not self.revoked and not self.is_expired
class OAuthConsent(Base, UUIDMixin, TimestampMixin):
"""
OAuth consent record - remembers user consent for a client.
When a user grants consent to an OAuth client, we store the record
so they don't have to re-consent on subsequent authorizations
(unless scopes change).
This enables a better UX - users only see consent screen once per client,
unless the client requests additional scopes.
"""
__tablename__ = "oauth_consents"
# User who granted consent
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Client that received consent
client_id = Column(
String(64),
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
nullable=False,
)
# Granted scopes (space-separated)
granted_scopes = Column(String(1000), nullable=False, default="")
# Relationships
client = relationship("OAuthClient", backref="consents")
user = relationship("User", backref="oauth_consents")
# Unique constraint: one consent record per user+client
__table_args__ = (
Index(
"ix_oauth_consents_user_client",
"user_id",
"client_id",
unique=True,
),
)
def __repr__(self):
return f"<OAuthConsent user={self.user_id} client={self.client_id}>"
def has_scopes(self, requested_scopes: list[str]) -> bool:
"""Check if all requested scopes are already granted."""
granted = set(self.granted_scopes.split()) if self.granted_scopes else set()
requested = set(requested_scopes)
return requested.issubset(granted)

View File

@@ -284,6 +284,9 @@ class OAuthServerMetadata(BaseModel):
revocation_endpoint: str | None = Field(
None, description="Token revocation endpoint"
)
introspection_endpoint: str | None = Field(
None, description="Token introspection endpoint (RFC 7662)"
)
scopes_supported: list[str] = Field(
default_factory=list, description="Supported scopes"
)
@@ -297,6 +300,10 @@ class OAuthServerMetadata(BaseModel):
code_challenge_methods_supported: list[str] = Field(
default_factory=lambda: ["S256"], description="Supported PKCE methods"
)
token_endpoint_auth_methods_supported: list[str] = Field(
default_factory=lambda: ["client_secret_basic", "client_secret_post", "none"],
description="Supported client authentication methods",
)
model_config = ConfigDict(
json_schema_extra={
@@ -304,10 +311,85 @@ class OAuthServerMetadata(BaseModel):
"issuer": "https://api.example.com",
"authorization_endpoint": "https://api.example.com/oauth/authorize",
"token_endpoint": "https://api.example.com/oauth/token",
"revocation_endpoint": "https://api.example.com/oauth/revoke",
"introspection_endpoint": "https://api.example.com/oauth/introspect",
"scopes_supported": ["openid", "profile", "email", "read:users"],
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"code_challenge_methods_supported": ["S256"],
"token_endpoint_auth_methods_supported": [
"client_secret_basic",
"client_secret_post",
"none",
],
}
}
)
# ============================================================================
# OAuth Token Responses (RFC 6749)
# ============================================================================
class OAuthTokenResponse(BaseModel):
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
access_token: str = Field(..., description="The access token issued by the server")
token_type: str = Field(
default="Bearer", description="The type of token (typically 'Bearer')"
)
expires_in: int | None = Field(None, description="Token lifetime in seconds")
refresh_token: str | None = Field(
None, description="Refresh token for obtaining new access tokens"
)
scope: str | None = Field(
None, description="Space-separated list of granted scopes"
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "dGhpcyBpcyBhIHJlZnJlc2ggdG9rZW4...",
"scope": "openid profile email",
}
}
)
class OAuthTokenIntrospectionResponse(BaseModel):
"""OAuth 2.0 Token Introspection Response (RFC 7662)."""
active: bool = Field(..., description="Whether the token is currently active")
scope: str | None = Field(None, description="Space-separated list of scopes")
client_id: str | None = Field(None, description="Client identifier for the token")
username: str | None = Field(
None, description="Human-readable identifier for the resource owner"
)
token_type: str | None = Field(
None, description="Type of the token (e.g., 'Bearer')"
)
exp: int | None = Field(None, description="Token expiration time (Unix timestamp)")
iat: int | None = Field(None, description="Token issue time (Unix timestamp)")
nbf: int | None = Field(None, description="Token not-before time (Unix timestamp)")
sub: str | None = Field(None, description="Subject of the token (user ID)")
aud: str | None = Field(None, description="Intended audience of the token")
iss: str | None = Field(None, description="Issuer of the token")
model_config = ConfigDict(
json_schema_extra={
"example": {
"active": True,
"scope": "openid profile",
"client_id": "client123",
"username": "user@example.com",
"token_type": "Bearer",
"exp": 1735689600,
"iat": 1735686000,
"sub": "user-uuid-here",
}
}
)

File diff suppressed because it is too large Load Diff

View File

@@ -246,6 +246,15 @@ class OAuthService:
if not state_record:
raise AuthenticationError("Invalid or expired OAuth state")
# SECURITY: Validate redirect_uri matches the one from authorization request
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
if state_record.redirect_uri != redirect_uri:
logger.warning(
f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, "
f"got {redirect_uri}"
)
raise AuthenticationError("Redirect URI mismatch")
# Extract provider from state record (str for type safety)
provider: str = str(state_record.provider)
@@ -272,6 +281,19 @@ class OAuthService:
config["token_url"],
**token_params,
)
# SECURITY: Validate ID token signature and nonce for OpenID Connect
# This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
if provider == "google" and state_record.nonce:
id_token = token.get("id_token")
if id_token:
await OAuthService._verify_google_id_token(
id_token=str(id_token),
expected_nonce=str(state_record.nonce),
client_id=client_id,
)
except AuthenticationError:
raise
except Exception as e:
logger.error(f"OAuth token exchange failed: {e!s}")
raise AuthenticationError("Failed to exchange authorization code")
@@ -294,8 +316,11 @@ class OAuthService:
# Process user info and create/link account
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
# Email can be None if user didn't grant email permission
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
email_raw = user_info.get("email")
provider_email: str | None = str(email_raw) if email_raw else None
provider_email: str | None = (
str(email_raw).lower().strip() if email_raw else None
)
if not provider_user_id:
raise AuthenticationError("Provider did not return user ID")
@@ -479,6 +504,106 @@ class OAuthService:
return user_info
# Google's OIDC configuration endpoints
GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
@staticmethod
async def _verify_google_id_token(
id_token: str,
expected_nonce: str,
client_id: str,
) -> dict[str, object]:
"""
Verify Google ID token signature and claims.
SECURITY: This properly verifies the ID token by:
1. Fetching Google's public keys (JWKS)
2. Verifying the JWT signature against the public key
3. Validating issuer, audience, expiry, and nonce claims
Args:
id_token: The ID token JWT string
expected_nonce: The nonce we sent in the authorization request
client_id: Our OAuth client ID (expected audience)
Returns:
Decoded ID token payload
Raises:
AuthenticationError: If verification fails
"""
import httpx
from jose import jwt as jose_jwt
from jose.exceptions import JWTError
try:
# Fetch Google's public keys (JWKS)
# In production, consider caching this with TTL matching Cache-Control header
async with httpx.AsyncClient() as client:
jwks_response = await client.get(
OAuthService.GOOGLE_JWKS_URL,
timeout=10.0,
)
jwks_response.raise_for_status()
jwks = jwks_response.json()
# Get the key ID from the token header
unverified_header = jose_jwt.get_unverified_header(id_token)
kid = unverified_header.get("kid")
if not kid:
raise AuthenticationError("ID token missing key ID (kid)")
# Find the matching public key
public_key = None
for key in jwks.get("keys", []):
if key.get("kid") == kid:
public_key = key
break
if not public_key:
raise AuthenticationError("ID token signed with unknown key")
# Verify the token signature and decode claims
# jose library will verify signature against the JWK
payload = jose_jwt.decode(
id_token,
public_key,
algorithms=["RS256"], # Google uses RS256
audience=client_id,
issuer=OAuthService.GOOGLE_ISSUERS,
options={
"verify_signature": True,
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"verify_iat": True,
},
)
# Verify nonce (OIDC replay attack protection)
token_nonce = payload.get("nonce")
if token_nonce != expected_nonce:
logger.warning(
f"OAuth ID token nonce mismatch: expected {expected_nonce}, "
f"got {token_nonce}"
)
raise AuthenticationError("Invalid ID token nonce")
logger.debug("Google ID token verified successfully")
return payload
except JWTError as e:
logger.warning(f"Google ID token verification failed: {e}")
raise AuthenticationError("Invalid ID token signature")
except httpx.HTTPError as e:
logger.error(f"Failed to fetch Google JWKS: {e}")
# If we can't verify the ID token, fail closed for security
raise AuthenticationError("Failed to verify ID token")
except Exception as e:
logger.error(f"Unexpected error verifying Google ID token: {e}")
raise AuthenticationError("ID token verification error")
@staticmethod
async def _create_oauth_user(
db: AsyncSession,

View File

@@ -344,8 +344,8 @@ class TestOAuthProviderEndpoints:
assert response.status_code == 404
@pytest.mark.asyncio
async def test_provider_authorize_skeleton(self, client, async_test_db):
"""Test provider authorize returns not implemented (skeleton)."""
async def test_provider_authorize_requires_auth(self, client, async_test_db):
"""Test provider authorize requires authentication."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a test client
@@ -374,12 +374,12 @@ class TestOAuthProviderEndpoints:
"redirect_uri": "http://localhost:3000/callback",
},
)
# Should return 501 Not Implemented (skeleton)
assert response.status_code == 501
# Authorize endpoint requires authentication
assert response.status_code == 401
@pytest.mark.asyncio
async def test_provider_token_skeleton(self, client):
"""Test provider token returns not implemented (skeleton)."""
async def test_provider_token_requires_client_id(self, client):
"""Test provider token requires client_id."""
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
mock_settings.OAUTH_PROVIDER_ENABLED = True
@@ -390,5 +390,5 @@ class TestOAuthProviderEndpoints:
"code": "test_code",
},
)
# Should return 501 Not Implemented (skeleton)
assert response.status_code == 501
# Missing client_id returns 401 (invalid_client)
assert response.status_code == 401

View File

@@ -203,3 +203,168 @@ async def e2e_client(async_postgres_url):
app.dependency_overrides.clear()
await engine.dispose()
@pytest_asyncio.fixture
async def e2e_superuser(e2e_client):
"""
Create a superuser and return credentials + tokens.
Returns dict with: email, password, tokens, user_id
"""
from uuid import uuid4
email = f"admin-{uuid4().hex[:8]}@example.com"
password = "SuperAdmin123!"
# Register via API first to get proper password hashing
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Super",
"last_name": "Admin",
},
)
# Login to get tokens
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
# Now we need to make this user a superuser directly via SQL
# Get the db session from the client's override
from sqlalchemy import text
from app.core.database import get_db
from app.main import app
async for db in app.dependency_overrides[get_db]():
# Update user to be superuser
await db.execute(
text("UPDATE users SET is_superuser = true WHERE email = :email"),
{"email": email},
)
await db.commit()
# Get user ID
result = await db.execute(
text("SELECT id FROM users WHERE email = :email"),
{"email": email},
)
user_id = str(result.scalar())
break
return {
"email": email,
"password": password,
"tokens": tokens,
"user_id": user_id,
}
@pytest_asyncio.fixture
async def e2e_org_with_members(e2e_client, e2e_superuser):
"""
Create an organization with owner and member.
Returns dict with: org_id, org_slug, owner (tokens), member (tokens)
"""
from uuid import uuid4
# Create organization via admin API
org_name = f"Test Org {uuid4().hex[:8]}"
org_slug = f"test-org-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/admin/organizations",
headers={"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"},
json={
"name": org_name,
"slug": org_slug,
"description": "Test organization for E2E tests",
},
)
org_data = create_resp.json()
org_id = org_data["id"]
# Create owner user
owner_email = f"owner-{uuid4().hex[:8]}@example.com"
owner_password = "OwnerPass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": owner_email,
"password": owner_password,
"first_name": "Org",
"last_name": "Owner",
},
)
owner_login = await e2e_client.post(
"/api/v1/auth/login",
json={"email": owner_email, "password": owner_password},
)
owner_tokens = owner_login.json()
# Get owner user ID
owner_me = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {owner_tokens['access_token']}"},
)
owner_id = owner_me.json()["id"]
# Add owner to organization as owner role
await e2e_client.post(
f"/api/v1/admin/organizations/{org_id}/members",
headers={"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"},
json={"user_id": owner_id, "role": "owner"},
)
# Create member user
member_email = f"member-{uuid4().hex[:8]}@example.com"
member_password = "MemberPass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": member_email,
"password": member_password,
"first_name": "Org",
"last_name": "Member",
},
)
member_login = await e2e_client.post(
"/api/v1/auth/login",
json={"email": member_email, "password": member_password},
)
member_tokens = member_login.json()
# Get member user ID
member_me = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {member_tokens['access_token']}"},
)
member_id = member_me.json()["id"]
# Add member to organization
await e2e_client.post(
f"/api/v1/admin/organizations/{org_id}/members",
headers={"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"},
json={"user_id": member_id, "role": "member"},
)
return {
"org_id": org_id,
"org_slug": org_slug,
"org_name": org_name,
"owner": {"email": owner_email, "tokens": owner_tokens, "user_id": owner_id},
"member": {
"email": member_email,
"tokens": member_tokens,
"user_id": member_id,
},
}

View File

@@ -0,0 +1,648 @@
"""
Admin superuser E2E workflow tests with real PostgreSQL.
These tests validate admin operations with actual superuser privileges:
- User management (list, create, update, delete, bulk actions)
- Organization management (create, update, delete, members)
- Admin statistics
Usage:
make test-e2e # Run all E2E tests
"""
from uuid import uuid4
import pytest
pytestmark = [
pytest.mark.e2e,
pytest.mark.postgres,
pytest.mark.asyncio,
]
class TestAdminUserManagement:
"""Test admin user management with superuser."""
async def test_admin_list_users(self, e2e_client, e2e_superuser):
"""Superuser can list all users."""
response = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
assert "data" in data
assert "pagination" in data
assert len(data["data"]) >= 1 # At least the superuser
async def test_admin_list_users_with_pagination(self, e2e_client, e2e_superuser):
"""Superuser can list users with pagination."""
# Create a few more users
for i in range(3):
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": f"user{i}-{uuid4().hex[:8]}@example.com",
"password": "TestPass123!",
"first_name": f"User{i}",
"last_name": "Test",
},
)
response = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
params={"page": 1, "limit": 2},
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) <= 2
assert data["pagination"]["page_size"] <= 2
async def test_admin_create_user(self, e2e_client, e2e_superuser):
"""Superuser can create new users."""
email = f"newuser-{uuid4().hex[:8]}@example.com"
response = await e2e_client.post(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"email": email,
"password": "NewUserPass123!",
"first_name": "New",
"last_name": "User",
},
)
assert response.status_code in [200, 201]
data = response.json()
assert data["email"] == email
async def test_admin_get_user_by_id(self, e2e_client, e2e_superuser):
"""Superuser can get any user by ID."""
# Create a user
email = f"target-{uuid4().hex[:8]}@example.com"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "TargetPass123!",
"first_name": "Target",
"last_name": "User",
},
)
# Get user list to find the ID
list_resp = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
users = list_resp.json()["data"]
target_user = next(u for u in users if u["email"] == email)
# Get user by ID
response = await e2e_client.get(
f"/api/v1/admin/users/{target_user['id']}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200
assert response.json()["email"] == email
async def test_admin_update_user(self, e2e_client, e2e_superuser):
"""Superuser can update any user."""
# Create a user
email = f"update-{uuid4().hex[:8]}@example.com"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "UpdatePass123!",
"first_name": "Update",
"last_name": "User",
},
)
# Get user ID
list_resp = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
users = list_resp.json()["data"]
target_user = next(u for u in users if u["email"] == email)
# Update user
response = await e2e_client.put(
f"/api/v1/admin/users/{target_user['id']}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"first_name": "Updated", "last_name": "Name"},
)
assert response.status_code == 200
assert response.json()["first_name"] == "Updated"
async def test_admin_deactivate_user(self, e2e_client, e2e_superuser):
"""Superuser can deactivate users."""
# Create a user
email = f"deactivate-{uuid4().hex[:8]}@example.com"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "DeactivatePass123!",
"first_name": "Deactivate",
"last_name": "User",
},
)
# Get user ID
list_resp = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
users = list_resp.json()["data"]
target_user = next(u for u in users if u["email"] == email)
# Deactivate user
response = await e2e_client.post(
f"/api/v1/admin/users/{target_user['id']}/deactivate",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200
async def test_admin_bulk_action(self, e2e_client, e2e_superuser):
"""Superuser can perform bulk actions on users."""
# Create users for bulk action
user_ids = []
for i in range(2):
email = f"bulk-{i}-{uuid4().hex[:8]}@example.com"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "BulkPass123!",
"first_name": f"Bulk{i}",
"last_name": "User",
},
)
# Get user IDs
list_resp = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
users = list_resp.json()["data"]
bulk_users = [u for u in users if u["email"].startswith("bulk-")]
user_ids = [u["id"] for u in bulk_users]
# Bulk deactivate
response = await e2e_client.post(
"/api/v1/admin/users/bulk-action",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"action": "deactivate", "user_ids": user_ids},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["affected_count"] >= 1
class TestAdminOrganizationManagement:
"""Test admin organization management with superuser."""
async def test_admin_list_organizations(self, e2e_client, e2e_superuser):
"""Superuser can list all organizations."""
response = await e2e_client.get(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
assert "data" in data
assert "pagination" in data
async def test_admin_create_organization(self, e2e_client, e2e_superuser):
"""Superuser can create organizations."""
org_name = f"Admin Org {uuid4().hex[:8]}"
org_slug = f"admin-org-{uuid4().hex[:8]}"
response = await e2e_client.post(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": org_name,
"slug": org_slug,
"description": "Created by admin",
},
)
assert response.status_code in [200, 201]
data = response.json()
assert data["name"] == org_name
assert data["slug"] == org_slug
async def test_admin_get_organization(self, e2e_client, e2e_superuser):
"""Superuser can get organization details."""
# Create org first
org_slug = f"get-org-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={
"name": "Get Org Test",
"slug": org_slug,
},
)
org_id = create_resp.json()["id"]
# Get org
response = await e2e_client.get(
f"/api/v1/admin/organizations/{org_id}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200
assert response.json()["slug"] == org_slug
async def test_admin_update_organization(self, e2e_client, e2e_superuser):
"""Superuser can update organizations."""
# Create org
org_slug = f"update-org-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Update Org Test", "slug": org_slug},
)
org_id = create_resp.json()["id"]
# Update org
response = await e2e_client.put(
f"/api/v1/admin/organizations/{org_id}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Updated Org Name", "description": "Updated description"},
)
assert response.status_code == 200
assert response.json()["name"] == "Updated Org Name"
async def test_admin_add_member_to_organization(self, e2e_client, e2e_superuser):
"""Superuser can add members to organizations."""
# Create org
org_slug = f"member-org-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Member Org Test", "slug": org_slug},
)
org_id = create_resp.json()["id"]
# Create user to add
email = f"new-member-{uuid4().hex[:8]}@example.com"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "MemberPass123!",
"first_name": "New",
"last_name": "Member",
},
)
# Get user ID
list_resp = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
users = list_resp.json()["data"]
new_user = next(u for u in users if u["email"] == email)
# Add to org
response = await e2e_client.post(
f"/api/v1/admin/organizations/{org_id}/members",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"user_id": new_user["id"], "role": "member"},
)
assert response.status_code in [200, 201]
async def test_admin_list_organization_members(self, e2e_client, e2e_superuser):
"""Superuser can list organization members."""
# Create org with member
org_slug = f"list-members-org-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "List Members Org", "slug": org_slug},
)
org_id = create_resp.json()["id"]
# List members
response = await e2e_client.get(
f"/api/v1/admin/organizations/{org_id}/members",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200
class TestAdminStats:
"""Test admin statistics endpoints."""
async def test_admin_get_stats(self, e2e_client, e2e_superuser):
"""Superuser can get admin statistics."""
response = await e2e_client.get(
"/api/v1/admin/stats",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
# Stats should have user growth, org distribution, etc.
assert "user_growth" in data or "user_status" in data
class TestAdminSessionManagement:
"""Test admin session management."""
async def test_admin_list_all_sessions(self, e2e_client, e2e_superuser):
"""Superuser can list all sessions."""
response = await e2e_client.get(
"/api/v1/admin/sessions",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
assert "data" in data
class TestAdminDeleteOperations:
"""Test admin delete operations."""
async def test_admin_delete_user(self, e2e_client, e2e_superuser):
"""Superuser can delete users."""
# Create user
email = f"delete-{uuid4().hex[:8]}@example.com"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "DeletePass123!",
"first_name": "Delete",
"last_name": "User",
},
)
# Get user ID
list_resp = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
users = list_resp.json()["data"]
target_user = next(u for u in users if u["email"] == email)
# Delete user
response = await e2e_client.delete(
f"/api/v1/admin/users/{target_user['id']}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code in [200, 204]
async def test_admin_delete_organization(self, e2e_client, e2e_superuser):
"""Superuser can delete organizations."""
# Create org
org_slug = f"delete-org-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Delete Org Test", "slug": org_slug},
)
org_id = create_resp.json()["id"]
# Delete org
response = await e2e_client.delete(
f"/api/v1/admin/organizations/{org_id}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code in [200, 204]
async def test_admin_remove_org_member(self, e2e_client, e2e_superuser):
"""Superuser can remove members from organizations."""
# Create org
org_slug = f"remove-member-org-{uuid4().hex[:8]}"
create_resp = await e2e_client.post(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": "Remove Member Org", "slug": org_slug},
)
org_id = create_resp.json()["id"]
# Create user
email = f"remove-member-{uuid4().hex[:8]}@example.com"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "RemovePass123!",
"first_name": "Remove",
"last_name": "Member",
},
)
# Get user ID
list_resp = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
users = list_resp.json()["data"]
target_user = next(u for u in users if u["email"] == email)
# Add to org
await e2e_client.post(
f"/api/v1/admin/organizations/{org_id}/members",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"user_id": target_user["id"], "role": "member"},
)
# Remove from org
response = await e2e_client.delete(
f"/api/v1/admin/organizations/{org_id}/members/{target_user['id']}",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code in [200, 204]
class TestAdminSearchAndFilter:
"""Test admin search and filter capabilities."""
async def test_admin_search_users_by_email(self, e2e_client, e2e_superuser):
"""Superuser can search users by email."""
# Create user with unique prefix
prefix = f"searchable-{uuid4().hex[:8]}"
email = f"{prefix}@example.com"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "SearchPass123!",
"first_name": "Search",
"last_name": "User",
},
)
response = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
params={"search": prefix},
)
assert response.status_code == 200
data = response.json()
# Search should find the user
assert len(data["data"]) >= 1
emails = [u["email"] for u in data["data"]]
assert any(prefix in e for e in emails)
async def test_admin_filter_active_users(self, e2e_client, e2e_superuser):
"""Superuser can filter by active status."""
response = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
params={"is_active": True},
)
assert response.status_code == 200
data = response.json()
# All returned users should be active
for user in data["data"]:
assert user["is_active"] is True
async def test_admin_filter_superusers(self, e2e_client, e2e_superuser):
"""Superuser can filter superusers."""
response = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
params={"is_superuser": True},
)
assert response.status_code == 200
data = response.json()
# Should find at least the test superuser
assert len(data["data"]) >= 1
async def test_admin_sort_users(self, e2e_client, e2e_superuser):
"""Superuser can sort users by different fields."""
response = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
params={"sort_by": "created_at", "sort_order": "desc"},
)
assert response.status_code == 200
data = response.json()
assert "data" in data
async def test_admin_search_organizations(self, e2e_client, e2e_superuser):
"""Superuser can search organizations."""
# Create org with unique name
prefix = f"searchorg-{uuid4().hex[:8]}"
await e2e_client.post(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
json={"name": f"{prefix} Test", "slug": f"{prefix}-slug"},
)
response = await e2e_client.get(
"/api/v1/admin/organizations",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
params={"search": prefix},
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) >= 1

View File

@@ -0,0 +1,212 @@
"""
Admin E2E workflow tests with real PostgreSQL.
These tests validate complete admin workflows including:
- User management (list, create, update, delete, bulk actions)
- Organization management (create, update, delete, members)
- Admin statistics
Usage:
make test-e2e # Run all E2E tests
"""
from uuid import uuid4
import pytest
pytestmark = [
pytest.mark.e2e,
pytest.mark.postgres,
pytest.mark.asyncio,
]
async def register_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
"""Helper to register a user."""
resp = await client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Test",
"last_name": "User",
},
)
return resp.json()
async def login_user(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
"""Helper to login a user."""
resp = await client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
return resp.json()
async def create_superuser(e2e_db_session, email: str, password: str):
"""Create a superuser directly in the database."""
from app.crud.user import user as user_crud
from app.schemas.users import UserCreate
user_in = UserCreate(
email=email,
password=password,
first_name="Admin",
last_name="User",
is_superuser=True,
)
user = await user_crud.create(e2e_db_session, obj_in=user_in)
return user
class TestAdminUserManagementWorkflows:
"""Test admin user management workflows."""
async def test_regular_user_cannot_access_admin_endpoints(self, e2e_client):
"""Regular users cannot access admin endpoints."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
await register_user(e2e_client, email)
tokens = await login_user(e2e_client, email)
response = await e2e_client.get(
"/api/v1/admin/users",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 403
async def test_admin_stats_requires_superuser(self, e2e_client):
"""Admin stats endpoint requires superuser."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
await register_user(e2e_client, email)
tokens = await login_user(e2e_client, email)
response = await e2e_client.get(
"/api/v1/admin/stats",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 403
async def test_admin_create_user_requires_superuser(self, e2e_client):
"""Creating users via admin endpoint requires superuser."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
await register_user(e2e_client, email)
tokens = await login_user(e2e_client, email)
response = await e2e_client.post(
"/api/v1/admin/users",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"email": f"newuser-{uuid4().hex[:8]}@example.com",
"password": "NewUserPass123!",
"first_name": "New",
"last_name": "User",
},
)
assert response.status_code == 403
class TestAdminOrganizationWorkflows:
"""Test admin organization management workflows."""
async def test_regular_user_cannot_list_admin_orgs(self, e2e_client):
"""Regular users cannot list organizations via admin endpoint."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
await register_user(e2e_client, email)
tokens = await login_user(e2e_client, email)
response = await e2e_client.get(
"/api/v1/admin/organizations",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 403
async def test_regular_user_cannot_create_org_via_admin(self, e2e_client):
"""Regular users cannot create organizations via admin endpoint."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
await register_user(e2e_client, email)
tokens = await login_user(e2e_client, email)
response = await e2e_client.post(
"/api/v1/admin/organizations",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"name": "Test Org",
"slug": f"test-org-{uuid4().hex[:8]}",
"description": "Test organization",
},
)
assert response.status_code == 403
class TestAdminSessionWorkflows:
"""Test admin session management workflows."""
async def test_regular_user_cannot_list_admin_sessions(self, e2e_client):
"""Regular users cannot list sessions via admin endpoint."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
await register_user(e2e_client, email)
tokens = await login_user(e2e_client, email)
response = await e2e_client.get(
"/api/v1/admin/sessions",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 403
class TestAdminBulkOperations:
"""Test admin bulk operation workflows."""
async def test_regular_user_cannot_bulk_activate_users(self, e2e_client):
"""Regular users cannot perform bulk user activation."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
await register_user(e2e_client, email)
tokens = await login_user(e2e_client, email)
response = await e2e_client.post(
"/api/v1/admin/users/bulk-action",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"action": "activate",
"user_ids": [str(uuid4())],
},
)
assert response.status_code == 403
class TestAdminAuthorizationBoundaries:
"""Test admin authorization security boundaries."""
async def test_unauthenticated_cannot_access_admin(self, e2e_client):
"""Unauthenticated requests cannot access admin endpoints."""
endpoints = [
("/api/v1/admin/users", "get"),
("/api/v1/admin/organizations", "get"),
("/api/v1/admin/sessions", "get"),
("/api/v1/admin/stats", "get"),
]
for endpoint, method in endpoints:
if method == "get":
response = await e2e_client.get(endpoint)
assert response.status_code == 401, f"Expected 401 for {endpoint}"
async def test_expired_token_rejected_for_admin(self, e2e_client):
"""Expired tokens are rejected for admin endpoints."""
# Use a clearly invalid/malformed token
fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
response = await e2e_client.get(
"/api/v1/admin/users",
headers={"Authorization": f"Bearer {fake_token}"},
)
assert response.status_code == 401

View File

@@ -40,56 +40,154 @@ if SCHEMATHESIS_AVAILABLE:
# Load schema from the FastAPI app using schemathesis.openapi (v4.x API)
schema = openapi.from_asgi("/api/v1/openapi.json", app=app)
# Test root endpoint (simple, always works)
# =========================================================================
# Public Endpoints (No Auth Required)
# =========================================================================
# Test root endpoint
root_schema = schema.include(path="/")
@root_schema.parametrize()
@settings(max_examples=5)
def test_root_endpoint_schema(case):
"""
Root endpoint schema compliance.
Tests that the root endpoint returns responses matching its schema.
"""
"""Root endpoint schema compliance."""
response = case.call()
# Just verify we get a response and no 5xx errors
assert response.status_code < 500, f"Server error: {response.text}"
# Test health endpoint
health_schema = schema.include(path="/health")
@health_schema.parametrize()
@settings(max_examples=3)
def test_health_endpoint_schema(case):
"""Health endpoint schema compliance."""
response = case.call()
# Health check may return 200 or 503 depending on DB
assert response.status_code < 500 or response.status_code == 503
# Test auth registration endpoint
# Note: This tests schema validation, not actual database operations
auth_register_schema = schema.include(path="/api/v1/auth/register")
@auth_register_schema.parametrize()
@settings(max_examples=10)
def test_register_endpoint_validates_input(case):
"""
Registration endpoint input validation.
Schemathesis generates various inputs to test validation.
The endpoint should never return 5xx errors for invalid input.
"""
"""Registration endpoint input validation."""
response = case.call()
# Registration returns 200/201 (success), 400/422 (validation), 409 (conflict)
# Never a 5xx error for validation issues
# 200/201 (success), 400/422 (validation), 409 (conflict)
assert response.status_code < 500, f"Server error: {response.text}"
# Note: Login and refresh endpoints require database, so they're tested
# in test_database_workflows.py instead of here. Schemathesis tests run
# without the testcontainers database fixtures.
# =========================================================================
# Protected Endpoints - Manual tests for auth requirements
# (Schemathesis parametrize tests all methods, manual tests are clearer)
# =========================================================================
class TestProtectedEndpointsRequireAuth:
"""Test that protected endpoints return proper auth errors."""
def test_users_me_requires_auth(self):
"""Users/me GET endpoint requires authentication."""
from starlette.testclient import TestClient
with TestClient(app) as client:
response = client.get("/api/v1/users/me")
assert response.status_code == 401
def test_sessions_me_requires_auth(self):
"""Sessions/me GET endpoint requires authentication."""
from starlette.testclient import TestClient
with TestClient(app) as client:
response = client.get("/api/v1/sessions/me")
assert response.status_code == 401
def test_organizations_me_requires_auth(self):
"""Organizations/me GET endpoint requires authentication."""
from starlette.testclient import TestClient
with TestClient(app) as client:
response = client.get("/api/v1/organizations/me")
assert response.status_code == 401
def test_admin_users_requires_auth(self):
"""Admin users GET endpoint requires authentication."""
from starlette.testclient import TestClient
with TestClient(app) as client:
response = client.get("/api/v1/admin/users")
assert response.status_code == 401
def test_admin_stats_requires_auth(self):
"""Admin stats GET endpoint requires authentication."""
from starlette.testclient import TestClient
with TestClient(app) as client:
response = client.get("/api/v1/admin/stats")
assert response.status_code == 401
def test_admin_organizations_requires_auth(self):
"""Admin organizations GET endpoint requires authentication."""
from starlette.testclient import TestClient
with TestClient(app) as client:
response = client.get("/api/v1/admin/organizations")
assert response.status_code == 401
# =========================================================================
# Schema Validation Tests
# =========================================================================
class TestSchemaValidation:
"""Manual validation tests for schema structure."""
def test_schema_loaded_successfully(self):
"""Verify schema was loaded from the app."""
# Count operations to verify schema loaded
ops = list(schema.get_all_operations())
assert len(ops) > 0, "No operations found in schema"
def test_multiple_endpoints_documented(self):
"""Verify multiple endpoints are documented in schema."""
ops = list(schema.get_all_operations())
# Should have at least 10 operations in a real API
assert len(ops) >= 10, f"Only {len(ops)} operations found"
def test_schema_has_auth_operations(self):
"""Verify auth-related operations exist."""
# Filter for auth endpoints
auth_ops = list(schema.include(path_regex=r".*auth.*").get_all_operations())
assert len(auth_ops) > 0, "No auth operations found"
def test_schema_has_user_operations(self):
"""Verify user-related operations exist."""
user_ops = list(
schema.include(path_regex=r".*users.*").get_all_operations()
)
assert len(user_ops) > 0, "No user operations found"
def test_schema_has_organization_operations(self):
"""Verify organization-related operations exist."""
org_ops = list(
schema.include(path_regex=r".*organizations.*").get_all_operations()
)
assert len(org_ops) > 0, "No organization operations found"
def test_schema_has_admin_operations(self):
"""Verify admin-related operations exist."""
admin_ops = list(
schema.include(path_regex=r".*admin.*").get_all_operations()
)
assert len(admin_ops) > 0, "No admin operations found"
def test_schema_has_session_operations(self):
"""Verify session-related operations exist."""
session_ops = list(
schema.include(path_regex=r".*sessions.*").get_all_operations()
)
assert len(session_ops) > 0, "No session operations found"
def test_total_endpoint_count(self):
"""Verify expected number of endpoints are documented."""
ops = list(schema.get_all_operations())
# We expect at least 40+ endpoints in this comprehensive API
assert len(ops) >= 40, f"Only {len(ops)} operations found, expected 40+"

View File

@@ -188,3 +188,134 @@ class TestHealthEndpoint:
assert response.status_code in [200, 503]
data = response.json()
assert "status" in data
class TestLogoutWorkflows:
"""Test logout workflows."""
async def test_logout_invalidates_session(self, e2e_client):
"""Test that logout invalidates the session."""
email = f"e2e-logout-{uuid4().hex[:8]}@example.com"
password = "SecurePassword123!"
# Register and login
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Logout",
"last_name": "Test",
},
)
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
# Logout requires both access token (auth) and refresh token (body)
logout_resp = await e2e_client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]},
)
assert logout_resp.status_code == 200
async def test_invalid_refresh_token_rejected(self, e2e_client):
"""Test that invalid refresh tokens are rejected."""
response = await e2e_client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid_refresh_token"},
)
assert response.status_code in [401, 422]
class TestValidationWorkflows:
"""Test input validation workflows."""
async def test_register_invalid_email(self, e2e_client):
"""Test that invalid email format is rejected."""
response = await e2e_client.post(
"/api/v1/auth/register",
json={
"email": "not_an_email",
"password": "ValidPassword123!",
"first_name": "Test",
"last_name": "User",
},
)
assert response.status_code == 422
async def test_register_weak_password(self, e2e_client):
"""Test that weak passwords are rejected."""
email = f"e2e-weak-{uuid4().hex[:8]}@example.com"
response = await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "weak", # Too weak
"first_name": "Test",
"last_name": "User",
},
)
assert response.status_code == 422
async def test_login_missing_fields(self, e2e_client):
"""Test that login requires all fields."""
response = await e2e_client.post(
"/api/v1/auth/login",
json={"email": "test@example.com"}, # Missing password
)
assert response.status_code == 422
class TestRootEndpoint:
"""Test root endpoint."""
async def test_root_responds(self, e2e_client):
"""Root endpoint should respond with HTML."""
response = await e2e_client.get("/")
assert response.status_code == 200
# Root returns HTML
assert "html" in response.text.lower() or "Welcome" in response.text
async def test_openapi_available(self, e2e_client):
"""OpenAPI schema should be available."""
response = await e2e_client.get("/api/v1/openapi.json")
assert response.status_code == 200
data = response.json()
assert "openapi" in data
assert "paths" in data
class TestAuthTokenWorkflows:
"""Test authentication token workflows."""
async def test_access_token_expires(self, e2e_client):
"""Test using expired access token."""
# Use a fake/expired token
fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiZXhwIjoxNjAwMDAwMDAwfQ.invalid"
response = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {fake_token}"},
)
assert response.status_code == 401
async def test_malformed_token_rejected(self, e2e_client):
"""Test that malformed tokens are rejected."""
response = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": "Bearer not-a-valid-token"},
)
assert response.status_code == 401
async def test_missing_bearer_prefix(self, e2e_client):
"""Test that tokens without Bearer prefix are rejected."""
response = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": "some-token"},
)
assert response.status_code == 401

View File

@@ -0,0 +1,353 @@
"""
Organization E2E workflow tests with real PostgreSQL.
These tests validate complete organization workflows including:
- Creating organizations (via admin)
- Viewing user's organizations
- Organization membership management
- Organization updates
Usage:
make test-e2e # Run all E2E tests
"""
from uuid import uuid4
import pytest
pytestmark = [
pytest.mark.e2e,
pytest.mark.postgres,
pytest.mark.asyncio,
]
async def register_and_login(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
"""Helper to register a user and get tokens."""
# Register
await client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Test",
"last_name": "User",
},
)
# Login
login_resp = await client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
tokens = login_resp.json()
return tokens
async def create_superuser_and_login(client, db_session):
"""Helper to create a superuser directly in DB and login."""
from app.crud.user import user as user_crud
from app.schemas.users import UserCreate
email = f"admin-{uuid4().hex[:8]}@example.com"
password = "AdminPassword123!"
# Create superuser directly
user_in = UserCreate(
email=email,
password=password,
first_name="Admin",
last_name="User",
is_superuser=True,
)
await user_crud.create(db_session, obj_in=user_in)
# Login
login_resp = await client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
return login_resp.json(), email
class TestOrganizationWorkflows:
"""Test organization management workflows."""
async def test_user_has_no_organizations_initially(self, e2e_client):
"""New users should have no organizations."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
response = await e2e_client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 0
async def test_get_organizations_requires_auth(self, e2e_client):
"""Organizations endpoint requires authentication."""
response = await e2e_client.get("/api/v1/organizations/me")
assert response.status_code == 401
async def test_get_nonexistent_organization(self, e2e_client):
"""Getting a non-member organization returns 403."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
fake_org_id = str(uuid4())
response = await e2e_client.get(
f"/api/v1/organizations/{fake_org_id}",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
# Should be 403 (not a member) or 404 (not found)
assert response.status_code in [403, 404]
class TestOrganizationMembershipWorkflows:
"""Test organization membership workflows."""
async def test_non_member_cannot_view_org_details(self, e2e_client):
"""Users cannot view organizations they're not members of."""
# Create two users
user1_email = f"e2e-user1-{uuid4().hex[:8]}@example.com"
user2_email = f"e2e-user2-{uuid4().hex[:8]}@example.com"
await register_and_login(e2e_client, user1_email)
user2_tokens = await register_and_login(e2e_client, user2_email)
# User2 tries to access a random org ID
fake_org_id = str(uuid4())
response = await e2e_client.get(
f"/api/v1/organizations/{fake_org_id}",
headers={"Authorization": f"Bearer {user2_tokens['access_token']}"},
)
assert response.status_code in [403, 404]
async def test_non_member_cannot_view_org_members(self, e2e_client):
"""Users cannot view members of organizations they don't belong to."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
fake_org_id = str(uuid4())
response = await e2e_client.get(
f"/api/v1/organizations/{fake_org_id}/members",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code in [403, 404]
async def test_non_admin_cannot_update_organization(self, e2e_client):
"""Regular users cannot update organizations (need admin role)."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
fake_org_id = str(uuid4())
response = await e2e_client.put(
f"/api/v1/organizations/{fake_org_id}",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"name": "Updated Name"},
)
assert response.status_code in [403, 404]
class TestOrganizationWithMembers:
"""Test organization workflows using e2e_org_with_members fixture."""
async def test_owner_can_view_organization(self, e2e_client, e2e_org_with_members):
"""Organization owner can view organization details."""
org = e2e_org_with_members
response = await e2e_client.get(
f"/api/v1/organizations/{org['org_id']}",
headers={
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == org["org_id"]
assert data["name"] == org["org_name"]
async def test_member_can_view_organization(self, e2e_client, e2e_org_with_members):
"""Organization member can view organization details."""
org = e2e_org_with_members
response = await e2e_client.get(
f"/api/v1/organizations/{org['org_id']}",
headers={
"Authorization": f"Bearer {org['member']['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == org["org_id"]
async def test_owner_can_list_members(self, e2e_client, e2e_org_with_members):
"""Organization owner can list members."""
org = e2e_org_with_members
response = await e2e_client.get(
f"/api/v1/organizations/{org['org_id']}/members",
headers={
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
# Should have owner + member = at least 2 members
assert len(data) >= 2
async def test_member_can_list_members(self, e2e_client, e2e_org_with_members):
"""Organization member can list members."""
org = e2e_org_with_members
response = await e2e_client.get(
f"/api/v1/organizations/{org['org_id']}/members",
headers={
"Authorization": f"Bearer {org['member']['tokens']['access_token']}"
},
)
assert response.status_code == 200
async def test_owner_appears_in_my_organizations(
self, e2e_client, e2e_org_with_members
):
"""Owner sees organization in their organizations list."""
org = e2e_org_with_members
response = await e2e_client.get(
"/api/v1/organizations/me",
headers={
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
org_ids = [o["id"] for o in data]
assert org["org_id"] in org_ids
async def test_member_appears_in_my_organizations(
self, e2e_client, e2e_org_with_members
):
"""Member sees organization in their organizations list."""
org = e2e_org_with_members
response = await e2e_client.get(
"/api/v1/organizations/me",
headers={
"Authorization": f"Bearer {org['member']['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
org_ids = [o["id"] for o in data]
assert org["org_id"] in org_ids
async def test_owner_can_update_organization(
self, e2e_client, e2e_org_with_members
):
"""Organization owner can update organization details."""
org = e2e_org_with_members
new_description = f"Updated at {uuid4().hex[:8]}"
response = await e2e_client.put(
f"/api/v1/organizations/{org['org_id']}",
headers={
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
},
json={"description": new_description},
)
assert response.status_code == 200
data = response.json()
assert data["description"] == new_description
async def test_member_cannot_update_organization(
self, e2e_client, e2e_org_with_members
):
"""Regular member cannot update organization details."""
org = e2e_org_with_members
response = await e2e_client.put(
f"/api/v1/organizations/{org['org_id']}",
headers={
"Authorization": f"Bearer {org['member']['tokens']['access_token']}"
},
json={"description": "Should fail"},
)
assert response.status_code == 403
async def test_non_member_cannot_view_organization(
self, e2e_client, e2e_org_with_members
):
"""Non-members cannot view organization details."""
org = e2e_org_with_members
# Create a new user who is not a member
non_member_email = f"nonmember-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, non_member_email)
response = await e2e_client.get(
f"/api/v1/organizations/{org['org_id']}",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 403
async def test_get_organization_by_slug(self, e2e_client, e2e_org_with_members):
"""Organization can be retrieved by slug."""
org = e2e_org_with_members
response = await e2e_client.get(
f"/api/v1/organizations/slug/{org['org_slug']}",
headers={
"Authorization": f"Bearer {org['owner']['tokens']['access_token']}"
},
)
# May be 200 or 404/403 depending on implementation
assert response.status_code in [200, 403, 404]
class TestOrganizationAdminOperations:
"""Test organization admin operations."""
async def test_admin_list_org_members_with_pagination(
self, e2e_client, e2e_superuser, e2e_org_with_members
):
"""Admin can list org members with pagination."""
org = e2e_org_with_members
response = await e2e_client.get(
f"/api/v1/admin/organizations/{org['org_id']}/members",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
params={"page": 1, "limit": 10},
)
assert response.status_code == 200
data = response.json()
assert "data" in data
assert "pagination" in data
async def test_admin_list_org_members_filter_active(
self, e2e_client, e2e_superuser, e2e_org_with_members
):
"""Admin can filter org members by active status."""
org = e2e_org_with_members
response = await e2e_client.get(
f"/api/v1/admin/organizations/{org['org_id']}/members",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
params={"is_active": True},
)
assert response.status_code == 200

View File

@@ -0,0 +1,331 @@
"""
Session management E2E workflow tests with real PostgreSQL.
These tests validate complete session management workflows including:
- Listing active sessions
- Session revocation
- Session cleanup
- Multi-device session handling
Usage:
make test-e2e # Run all E2E tests
"""
from uuid import uuid4
import pytest
pytestmark = [
pytest.mark.e2e,
pytest.mark.postgres,
pytest.mark.asyncio,
]
async def register_and_login(
client,
email: str,
password: str = "SecurePassword123!", # noqa: S107
user_agent: str | None = None,
):
"""Helper to register a user and get tokens."""
await client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Test",
"last_name": "User",
},
)
headers = {}
if user_agent:
headers["User-Agent"] = user_agent
login_resp = await client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
headers=headers,
)
return login_resp.json()
class TestSessionListingWorkflows:
"""Test session listing workflows."""
async def test_list_sessions_after_login(self, e2e_client):
"""Users can list their active sessions after login."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
response = await e2e_client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 200
data = response.json()
assert "sessions" in data
assert "total" in data
assert data["total"] >= 1
assert len(data["sessions"]) >= 1
async def test_session_contains_expected_fields(self, e2e_client):
"""Session response contains expected fields."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
response = await e2e_client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
data = response.json()
session = data["sessions"][0]
# Check required fields
assert "id" in session
assert "created_at" in session
assert "last_used_at" in session
assert "is_current" in session
async def test_list_sessions_requires_auth(self, e2e_client):
"""Sessions endpoint requires authentication."""
response = await e2e_client.get("/api/v1/sessions/me")
assert response.status_code == 401
async def test_multiple_logins_create_multiple_sessions(self, e2e_client):
"""Multiple logins create multiple sessions."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
password = "SecurePassword123!"
# Register
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Test",
"last_name": "User",
},
)
# Login multiple times with different user agents
tokens1 = (
await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"},
)
).json()
# Second login to create another session
await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
headers={"User-Agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0)"},
)
# Check sessions using first token
response = await e2e_client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
)
data = response.json()
assert data["total"] >= 2
class TestSessionRevocationWorkflows:
"""Test session revocation workflows."""
async def test_revoke_own_session(self, e2e_client):
"""Users can revoke their own sessions."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
password = "SecurePassword123!"
# Register
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Test",
"last_name": "User",
},
)
# Create two sessions
tokens1 = (
await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
).json()
# Second login to create another session
await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
# Get sessions
sessions_resp = await e2e_client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
)
sessions = sessions_resp.json()["sessions"]
initial_count = len(sessions)
# Revoke one session (not the current one)
session_to_revoke = sessions[-1]["id"]
revoke_resp = await e2e_client.delete(
f"/api/v1/sessions/{session_to_revoke}",
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
)
assert revoke_resp.status_code == 200
assert revoke_resp.json()["success"] is True
# Verify session count decreased
updated_sessions_resp = await e2e_client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
)
updated_count = updated_sessions_resp.json()["total"]
assert updated_count == initial_count - 1
async def test_cannot_revoke_nonexistent_session(self, e2e_client):
"""Cannot revoke a session that doesn't exist."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
fake_session_id = str(uuid4())
response = await e2e_client.delete(
f"/api/v1/sessions/{fake_session_id}",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 404
async def test_cannot_revoke_other_user_session(self, e2e_client):
"""Users cannot revoke other users' sessions."""
user1_email = f"e2e-user1-{uuid4().hex[:8]}@example.com"
user2_email = f"e2e-user2-{uuid4().hex[:8]}@example.com"
tokens1 = await register_and_login(e2e_client, user1_email)
tokens2 = await register_and_login(e2e_client, user2_email)
# Get user1's session ID
sessions_resp = await e2e_client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
)
user1_session_id = sessions_resp.json()["sessions"][0]["id"]
# User2 tries to revoke user1's session
response = await e2e_client.delete(
f"/api/v1/sessions/{user1_session_id}",
headers={"Authorization": f"Bearer {tokens2['access_token']}"},
)
assert response.status_code == 403
class TestSessionCleanupWorkflows:
"""Test session cleanup workflows."""
async def test_cleanup_expired_sessions(self, e2e_client):
"""Users can cleanup their expired sessions."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
response = await e2e_client.delete(
"/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "Cleaned up" in data["message"]
async def test_cleanup_requires_auth(self, e2e_client):
"""Session cleanup requires authentication."""
response = await e2e_client.delete("/api/v1/sessions/me/expired")
assert response.status_code == 401
class TestLogoutWorkflows:
"""Test logout workflows."""
async def test_logout_invalidates_session(self, e2e_client):
"""Logout should invalidate the session."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
# Logout
logout_resp = await e2e_client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]},
)
assert logout_resp.status_code == 200
# Refresh token should no longer work
refresh_resp = await e2e_client.post(
"/api/v1/auth/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
# May be 401 or 400 depending on implementation
assert refresh_resp.status_code in [400, 401]
async def test_logout_all_revokes_all_sessions(self, e2e_client):
"""Logout all should revoke all sessions."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
password = "SecurePassword123!"
# Register
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Test",
"last_name": "User",
},
)
# Create multiple sessions
tokens1 = (
await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
).json()
tokens2 = (
await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
).json()
# Logout all
logout_resp = await e2e_client.post(
"/api/v1/auth/logout-all",
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
)
assert logout_resp.status_code == 200
# Second token's refresh should no longer work
refresh_resp = await e2e_client.post(
"/api/v1/auth/refresh",
json={"refresh_token": tokens2["refresh_token"]},
)
assert refresh_resp.status_code in [400, 401]

View File

@@ -0,0 +1,351 @@
"""
User management E2E workflow tests with real PostgreSQL.
These tests validate complete user management workflows including:
- Profile viewing and updates
- Password changes
- User settings management
Usage:
make test-e2e # Run all E2E tests
"""
from uuid import uuid4
import pytest
pytestmark = [
pytest.mark.e2e,
pytest.mark.postgres,
pytest.mark.asyncio,
]
async def register_and_login(client, email: str, password: str = "SecurePassword123!"): # noqa: S107
"""Helper to register a user and get tokens."""
await client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Test",
"last_name": "User",
},
)
login_resp = await client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
return login_resp.json()
class TestUserProfileWorkflows:
"""Test user profile management workflows."""
async def test_get_own_profile(self, e2e_client):
"""Users can view their own profile."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
response = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 200
data = response.json()
assert data["email"] == email
assert data["first_name"] == "Test"
assert data["last_name"] == "User"
assert "id" in data
assert "is_active" in data
async def test_update_own_profile(self, e2e_client):
"""Users can update their own profile."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
response = await e2e_client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"first_name": "Updated",
"last_name": "Name",
},
)
assert response.status_code == 200
data = response.json()
assert data["first_name"] == "Updated"
assert data["last_name"] == "Name"
# Verify changes persisted
verify_resp = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert verify_resp.json()["first_name"] == "Updated"
async def test_profile_requires_auth(self, e2e_client):
"""Profile endpoints require authentication."""
response = await e2e_client.get("/api/v1/users/me")
assert response.status_code == 401
async def test_get_user_by_id_own_profile(self, e2e_client):
"""Users can get their own profile by ID."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
# Get user ID from /me endpoint
me_resp = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
user_id = me_resp.json()["id"]
# Get by ID
response = await e2e_client.get(
f"/api/v1/users/{user_id}",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 200
assert response.json()["id"] == user_id
async def test_cannot_get_other_user_profile(self, e2e_client):
"""Regular users cannot view other users' profiles."""
# Create two users
user1_email = f"e2e-user1-{uuid4().hex[:8]}@example.com"
user2_email = f"e2e-user2-{uuid4().hex[:8]}@example.com"
tokens1 = await register_and_login(e2e_client, user1_email)
tokens2 = await register_and_login(e2e_client, user2_email)
# Get user1's ID
me_resp = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
)
user1_id = me_resp.json()["id"]
# User2 tries to access user1's profile
response = await e2e_client.get(
f"/api/v1/users/{user1_id}",
headers={"Authorization": f"Bearer {tokens2['access_token']}"},
)
assert response.status_code == 403
class TestPasswordChangeWorkflows:
"""Test password change workflows."""
async def test_change_password_success(self, e2e_client):
"""Users can change their password with correct current password."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
old_password = "OldPassword123!"
new_password = "NewPassword456!"
tokens = await register_and_login(e2e_client, email, old_password)
response = await e2e_client.patch(
"/api/v1/users/me/password",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"current_password": old_password,
"new_password": new_password,
},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# Verify new password works
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": new_password},
)
assert login_resp.status_code == 200
async def test_change_password_wrong_current(self, e2e_client):
"""Password change fails with wrong current password."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
response = await e2e_client.patch(
"/api/v1/users/me/password",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"current_password": "WrongPassword123!",
"new_password": "NewPassword456!",
},
)
assert response.status_code == 403
async def test_change_password_weak_new_password(self, e2e_client):
"""Password change fails with weak new password."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
password = "SecurePassword123!"
tokens = await register_and_login(e2e_client, email, password)
response = await e2e_client.patch(
"/api/v1/users/me/password",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"current_password": password,
"new_password": "weak", # Too weak
},
)
assert response.status_code == 422 # Validation error
async def test_old_password_invalid_after_change(self, e2e_client):
"""Old password no longer works after password change."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
old_password = "OldPassword123!"
new_password = "NewPassword456!"
tokens = await register_and_login(e2e_client, email, old_password)
# Change password
await e2e_client.patch(
"/api/v1/users/me/password",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={
"current_password": old_password,
"new_password": new_password,
},
)
# Old password should fail
login_resp = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": old_password},
)
assert login_resp.status_code == 401
class TestUserUpdateWorkflows:
"""Test user update edge cases."""
async def test_cannot_elevate_own_privileges(self, e2e_client):
"""Users cannot make themselves superusers."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
# Try to make self superuser - should be silently ignored or rejected
response = await e2e_client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"is_superuser": True},
)
# The request may succeed but is_superuser should not change
if response.status_code == 200:
data = response.json()
assert data.get("is_superuser") is False
else:
# Or it may be rejected outright
assert response.status_code in [400, 403, 422]
async def test_cannot_update_other_user_profile(self, e2e_client):
"""Regular users cannot update other users' profiles."""
user1_email = f"e2e-user1-{uuid4().hex[:8]}@example.com"
user2_email = f"e2e-user2-{uuid4().hex[:8]}@example.com"
tokens1 = await register_and_login(e2e_client, user1_email)
tokens2 = await register_and_login(e2e_client, user2_email)
# Get user1's ID
me_resp = await e2e_client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {tokens1['access_token']}"},
)
user1_id = me_resp.json()["id"]
# User2 tries to update user1
response = await e2e_client.patch(
f"/api/v1/users/{user1_id}",
headers={"Authorization": f"Bearer {tokens2['access_token']}"},
json={"first_name": "Hacked"},
)
assert response.status_code == 403
class TestAdminUserListWorkflows:
"""Test admin user list workflows via /users endpoint."""
async def test_superuser_can_list_all_users(self, e2e_client, e2e_superuser):
"""Superuser can list all users via /users endpoint."""
response = await e2e_client.get(
"/api/v1/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
assert response.status_code == 200
data = response.json()
assert "data" in data
assert "pagination" in data
async def test_regular_user_cannot_list_users(self, e2e_client):
"""Regular users cannot list all users."""
email = f"e2e-{uuid4().hex[:8]}@example.com"
tokens = await register_and_login(e2e_client, email)
response = await e2e_client.get(
"/api/v1/users",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == 403
class TestDeactivatedUserWorkflows:
"""Test workflows involving deactivated users."""
async def test_deactivated_user_cannot_login(self, e2e_client, e2e_superuser):
"""Deactivated users cannot login."""
# Create user
email = f"deactivate-login-{uuid4().hex[:8]}@example.com"
password = "DeactivatePass123!"
await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": password,
"first_name": "Deactivate",
"last_name": "Login",
},
)
# Get user ID
list_resp = await e2e_client.get(
"/api/v1/admin/users",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
users = list_resp.json()["data"]
target_user = next(u for u in users if u["email"] == email)
# Deactivate user
await e2e_client.post(
f"/api/v1/admin/users/{target_user['id']}/deactivate",
headers={
"Authorization": f"Bearer {e2e_superuser['tokens']['access_token']}"
},
)
# Try to login - should fail
response = await e2e_client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
assert response.status_code in [401, 403]

View File

@@ -0,0 +1,772 @@
# tests/services/test_oauth_provider_service.py
"""
Tests for OAuth Provider Service (Authorization Server mode for MCP).
Covers:
- Authorization code creation and exchange
- Token generation, refresh, and revocation
- PKCE verification
- Token introspection (RFC 7662)
- Consent management
- Error handling
"""
import base64
import hashlib
import secrets
from unittest.mock import patch
from uuid import uuid4
import pytest
import pytest_asyncio
from app.models.oauth_client import OAuthClient
from app.models.user import User
from app.services import oauth_provider_service as service
from app.utils.test_utils import setup_async_test_db, teardown_async_test_db
@pytest_asyncio.fixture(scope="function")
async def db():
"""Fixture provides testing engine and session for each test."""
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
async with AsyncTestingSessionLocal() as session:
yield session
await teardown_async_test_db(test_engine)
@pytest_asyncio.fixture
async def test_user(db):
"""Create a test user."""
user = User(
id=uuid4(),
email="testuser@example.com",
password_hash="$2b$12$test",
first_name="Test",
last_name="User",
is_active=True,
is_superuser=False,
)
db.add(user)
await db.commit()
await db.refresh(user)
return user
@pytest_asyncio.fixture
async def public_client(db):
"""Create a test public OAuth client."""
client = OAuthClient(
id=uuid4(),
client_id="test_public_client",
client_name="Test Public Client",
client_type="public",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["openid", "profile", "email", "read:users"],
is_active=True,
)
db.add(client)
await db.commit()
await db.refresh(client)
return client
@pytest_asyncio.fixture
async def confidential_client(db):
"""Create a test confidential OAuth client using bcrypt."""
from app.core.auth import get_password_hash
secret = "test_client_secret"
# Use bcrypt for new client secret hashing (security improvement)
secret_hash = get_password_hash(secret)
client = OAuthClient(
id=uuid4(),
client_id="test_confidential_client",
client_name="Test Confidential Client",
client_type="confidential",
client_secret_hash=secret_hash,
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["openid", "profile", "email"],
is_active=True,
)
db.add(client)
await db.commit()
await db.refresh(client)
return client, secret
@pytest_asyncio.fixture
async def confidential_client_legacy_hash(db):
"""Create a test confidential OAuth client with legacy SHA-256 hash."""
# This tests backward compatibility with old SHA-256 hashed secrets
secret = "test_legacy_secret"
secret_hash = hashlib.sha256(secret.encode()).hexdigest()
client = OAuthClient(
id=uuid4(),
client_id="test_legacy_client",
client_name="Test Legacy Client",
client_type="confidential",
client_secret_hash=secret_hash,
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["openid", "profile"],
is_active=True,
)
db.add(client)
await db.commit()
await db.refresh(client)
return client, secret
class TestHelperFunctions:
"""Tests for helper functions."""
def test_generate_code_length(self):
"""Test authorization code generation has proper length."""
code = service.generate_code()
assert len(code) > 64 # Base64 encoding of 64 bytes
def test_generate_code_unique(self):
"""Test authorization codes are unique."""
codes = [service.generate_code() for _ in range(100)]
assert len(set(codes)) == 100
def test_generate_token(self):
"""Test token generation."""
token = service.generate_token()
assert len(token) > 32
def test_generate_jti(self):
"""Test JTI generation."""
jti = service.generate_jti()
assert len(jti) > 20
def test_hash_token(self):
"""Test token hashing."""
token = "test_token"
hashed = service.hash_token(token)
assert len(hashed) == 64 # SHA-256 hex digest
def test_hash_token_deterministic(self):
"""Test same token produces same hash."""
token = "test_token"
hash1 = service.hash_token(token)
hash2 = service.hash_token(token)
assert hash1 == hash2
def test_parse_scope(self):
"""Test scope parsing."""
assert service.parse_scope("openid profile email") == [
"openid",
"profile",
"email",
]
assert service.parse_scope("") == []
assert service.parse_scope(" openid profile ") == ["openid", "profile"]
def test_join_scope(self):
"""Test scope joining."""
# Result is sorted and deduplicated
result = service.join_scope(["profile", "openid", "profile"])
assert "openid" in result
assert "profile" in result
class TestPKCEVerification:
"""Tests for PKCE verification."""
def test_verify_pkce_s256_valid(self):
"""Test PKCE verification with S256 method."""
# Generate code_verifier
code_verifier = secrets.token_urlsafe(64)
# Generate code_challenge using S256
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
assert service.verify_pkce(code_verifier, code_challenge, "S256") is True
def test_verify_pkce_s256_invalid(self):
"""Test PKCE verification fails with wrong verifier."""
code_verifier = secrets.token_urlsafe(64)
wrong_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
assert service.verify_pkce(wrong_verifier, code_challenge, "S256") is False
def test_verify_pkce_plain_rejected(self):
"""Test PKCE verification rejects 'plain' method for security."""
# SECURITY: 'plain' method provides no security benefit and must be rejected
# per RFC 7636 Section 4.3 - only S256 is allowed
code_verifier = "test_verifier"
assert service.verify_pkce(code_verifier, code_verifier, "plain") is False
def test_verify_pkce_unknown_method(self):
"""Test PKCE verification with unknown method returns False."""
assert service.verify_pkce("verifier", "challenge", "unknown") is False
class TestClientValidation:
"""Tests for client validation."""
@pytest.mark.asyncio
async def test_get_client_success(self, db, public_client):
"""Test getting a valid client."""
client = await service.get_client(db, public_client.client_id)
assert client is not None
assert client.client_id == public_client.client_id
@pytest.mark.asyncio
async def test_get_client_not_found(self, db):
"""Test getting a non-existent client."""
client = await service.get_client(db, "nonexistent")
assert client is None
@pytest.mark.asyncio
async def test_get_client_inactive(self, db, public_client):
"""Test getting an inactive client returns None."""
public_client.is_active = False
await db.commit()
client = await service.get_client(db, public_client.client_id)
assert client is None
@pytest.mark.asyncio
async def test_validate_client_public(self, db, public_client):
"""Test validating a public client."""
client = await service.validate_client(db, public_client.client_id)
assert client.client_id == public_client.client_id
@pytest.mark.asyncio
async def test_validate_client_confidential_with_secret(
self, db, confidential_client
):
"""Test validating a confidential client with correct secret."""
client, secret = confidential_client
validated = await service.validate_client(db, client.client_id, secret)
assert validated.client_id == client.client_id
@pytest.mark.asyncio
async def test_validate_client_confidential_wrong_secret(
self, db, confidential_client
):
"""Test validating a confidential client with wrong secret."""
client, _ = confidential_client
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
await service.validate_client(db, client.client_id, "wrong_secret")
@pytest.mark.asyncio
async def test_validate_client_confidential_no_secret(
self, db, confidential_client
):
"""Test validating a confidential client without secret."""
client, _ = confidential_client
with pytest.raises(service.InvalidClientError, match="Client secret required"):
await service.validate_client(db, client.client_id)
@pytest.mark.asyncio
async def test_validate_client_legacy_sha256_hash(
self, db, confidential_client_legacy_hash
):
"""Test validating a client with legacy SHA-256 hash (backward compatibility)."""
client, secret = confidential_client_legacy_hash
validated = await service.validate_client(db, client.client_id, secret)
assert validated.client_id == client.client_id
@pytest.mark.asyncio
async def test_validate_client_legacy_sha256_wrong_secret(
self, db, confidential_client_legacy_hash
):
"""Test legacy SHA-256 client rejects wrong secret."""
client, _ = confidential_client_legacy_hash
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
await service.validate_client(db, client.client_id, "wrong_secret")
def test_validate_redirect_uri_success(self, public_client):
"""Test validating a registered redirect URI."""
# Should not raise
service.validate_redirect_uri(public_client, "http://localhost:3000/callback")
def test_validate_redirect_uri_invalid(self, public_client):
"""Test validating an unregistered redirect URI."""
with pytest.raises(service.InvalidRequestError, match="Invalid redirect_uri"):
service.validate_redirect_uri(public_client, "http://evil.com/callback")
def test_validate_redirect_uri_no_uris(self, public_client):
"""Test validating when client has no URIs."""
public_client.redirect_uris = []
with pytest.raises(service.InvalidRequestError, match="no registered"):
service.validate_redirect_uri(public_client, "http://localhost:3000")
class TestScopeValidation:
"""Tests for scope validation."""
def test_validate_scopes_all_valid(self, public_client):
"""Test validating all valid scopes."""
scopes = service.validate_scopes(public_client, ["openid", "profile"])
assert "openid" in scopes
assert "profile" in scopes
def test_validate_scopes_partial_valid(self, public_client):
"""Test validating with some invalid scopes - filters them out."""
scopes = service.validate_scopes(public_client, ["openid", "invalid_scope"])
assert "openid" in scopes
assert "invalid_scope" not in scopes
def test_validate_scopes_empty_uses_all_allowed(self, public_client):
"""Test empty scope request uses all allowed scopes."""
scopes = service.validate_scopes(public_client, [])
assert set(scopes) == set(public_client.allowed_scopes)
def test_validate_scopes_none_valid(self, public_client):
"""Test validating with no valid scopes raises error."""
with pytest.raises(service.InvalidScopeError):
service.validate_scopes(public_client, ["invalid1", "invalid2"])
class TestAuthorizationCode:
"""Tests for authorization code creation and exchange."""
@pytest.mark.asyncio
async def test_create_authorization_code_public_with_pkce(
self, db, public_client, test_user
):
"""Test creating authorization code for public client with PKCE."""
code_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
code = await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid profile",
code_challenge=code_challenge,
code_challenge_method="S256",
)
assert code is not None
assert len(code) > 64
@pytest.mark.asyncio
async def test_create_authorization_code_public_without_pkce_fails(
self, db, public_client, test_user
):
"""Test creating authorization code for public client without PKCE fails."""
with pytest.raises(service.InvalidRequestError, match="PKCE"):
await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid",
)
@pytest.mark.asyncio
async def test_exchange_authorization_code_success(
self, db, public_client, test_user
):
"""Test exchanging valid authorization code for tokens."""
# Create PKCE challenge
code_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
# Create auth code
code = await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid profile",
code_challenge=code_challenge,
code_challenge_method="S256",
)
# Exchange code
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.exchange_authorization_code(
db=db,
code=code,
client_id=public_client.client_id,
redirect_uri="http://localhost:3000/callback",
code_verifier=code_verifier,
)
assert "access_token" in result
assert "refresh_token" in result
assert result["token_type"] == "Bearer"
assert "expires_in" in result
@pytest.mark.asyncio
async def test_exchange_authorization_code_invalid_code(self, db, public_client):
"""Test exchanging invalid code fails."""
with pytest.raises(service.InvalidGrantError, match="Invalid authorization"):
await service.exchange_authorization_code(
db=db,
code="invalid_code",
client_id=public_client.client_id,
redirect_uri="http://localhost:3000/callback",
)
@pytest.mark.asyncio
async def test_exchange_authorization_code_wrong_redirect_uri(
self, db, public_client, test_user
):
"""Test exchanging code with wrong redirect_uri fails."""
code_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
code = await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid",
code_challenge=code_challenge,
code_challenge_method="S256",
)
with pytest.raises(service.InvalidGrantError, match="redirect_uri mismatch"):
await service.exchange_authorization_code(
db=db,
code=code,
client_id=public_client.client_id,
redirect_uri="http://different.com/callback",
code_verifier=code_verifier,
)
@pytest.mark.asyncio
async def test_exchange_authorization_code_invalid_pkce(
self, db, public_client, test_user
):
"""Test exchanging code with invalid PKCE verifier fails."""
code_verifier = secrets.token_urlsafe(64)
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
code = await service.create_authorization_code(
db=db,
client=public_client,
user=test_user,
redirect_uri="http://localhost:3000/callback",
scope="openid",
code_challenge=code_challenge,
code_challenge_method="S256",
)
with pytest.raises(service.InvalidGrantError, match="Invalid code_verifier"):
await service.exchange_authorization_code(
db=db,
code=code,
client_id=public_client.client_id,
redirect_uri="http://localhost:3000/callback",
code_verifier="wrong_verifier",
)
class TestTokenRefresh:
"""Tests for token refresh."""
@pytest.mark.asyncio
async def test_refresh_tokens_success(self, db, public_client, test_user):
"""Test refreshing tokens successfully."""
# Create initial tokens
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid profile",
)
refresh_token = result["refresh_token"]
# Refresh the tokens
new_result = await service.refresh_tokens(
db=db,
refresh_token=refresh_token,
client_id=public_client.client_id,
)
assert "access_token" in new_result
assert "refresh_token" in new_result
assert new_result["refresh_token"] != refresh_token # Token rotation
@pytest.mark.asyncio
async def test_refresh_tokens_invalid_token(self, db, public_client):
"""Test refreshing with invalid token fails."""
with pytest.raises(service.InvalidGrantError, match="Invalid refresh token"):
await service.refresh_tokens(
db=db,
refresh_token="invalid_token",
client_id=public_client.client_id,
)
@pytest.mark.asyncio
async def test_refresh_tokens_scope_reduction(self, db, public_client, test_user):
"""Test refreshing with reduced scope."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid profile email",
)
new_result = await service.refresh_tokens(
db=db,
refresh_token=result["refresh_token"],
client_id=public_client.client_id,
scope="openid", # Reduced scope
)
assert "openid" in new_result["scope"]
assert "profile" not in new_result["scope"]
@pytest.mark.asyncio
async def test_refresh_tokens_scope_expansion_fails(
self, db, public_client, test_user
):
"""Test refreshing with expanded scope fails."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid",
)
with pytest.raises(service.InvalidScopeError, match="Cannot expand scope"):
await service.refresh_tokens(
db=db,
refresh_token=result["refresh_token"],
client_id=public_client.client_id,
scope="openid profile", # Expanded scope
)
class TestTokenRevocation:
"""Tests for token revocation."""
@pytest.mark.asyncio
async def test_revoke_refresh_token(self, db, public_client, test_user):
"""Test revoking a refresh token."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid",
)
# Revoke the token
revoked = await service.revoke_token(
db=db,
token=result["refresh_token"],
token_type_hint="refresh_token",
)
assert revoked is True
# Try to use revoked token
with pytest.raises(service.InvalidGrantError, match="revoked"):
await service.refresh_tokens(
db=db,
refresh_token=result["refresh_token"],
client_id=public_client.client_id,
)
@pytest.mark.asyncio
async def test_revoke_all_user_tokens(self, db, public_client, test_user):
"""Test revoking all tokens for a user."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
# Create multiple tokens (we don't need to capture results)
await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid",
)
await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="profile",
)
# Revoke all
count = await service.revoke_all_user_tokens(db, test_user.id)
assert count == 2
class TestTokenIntrospection:
"""Tests for token introspection (RFC 7662)."""
@pytest.mark.asyncio
async def test_introspect_valid_access_token(self, db, public_client, test_user):
"""Test introspecting a valid access token."""
with patch("app.services.oauth_provider_service.settings") as mock_settings:
mock_settings.OAUTH_ISSUER = "http://localhost:8000"
mock_settings.SECRET_KEY = "test_secret_key_for_jwt_signing_123456"
mock_settings.ALGORITHM = "HS256"
result = await service.create_tokens(
db=db,
client=public_client,
user=test_user,
scope="openid profile",
)
introspection = await service.introspect_token(
db=db,
token=result["access_token"],
)
assert introspection["active"] is True
assert introspection["client_id"] == public_client.client_id
assert introspection["sub"] == str(test_user.id)
@pytest.mark.asyncio
async def test_introspect_invalid_token(self, db):
"""Test introspecting an invalid token."""
introspection = await service.introspect_token(
db=db,
token="invalid_token",
)
assert introspection["active"] is False
class TestConsentManagement:
"""Tests for consent management."""
@pytest.mark.asyncio
async def test_grant_consent(self, db, public_client, test_user):
"""Test granting consent."""
consent = await service.grant_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
scopes=["openid", "profile"],
)
assert consent is not None
assert "openid" in consent.granted_scopes
assert "profile" in consent.granted_scopes
@pytest.mark.asyncio
async def test_check_consent_granted(self, db, public_client, test_user):
"""Test checking granted consent."""
await service.grant_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
scopes=["openid", "profile"],
)
has_consent = await service.check_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
requested_scopes=["openid"],
)
assert has_consent is True
@pytest.mark.asyncio
async def test_check_consent_not_granted(self, db, public_client, test_user):
"""Test checking consent that hasn't been granted."""
has_consent = await service.check_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
requested_scopes=["openid"],
)
assert has_consent is False
@pytest.mark.asyncio
async def test_revoke_consent(self, db, public_client, test_user):
"""Test revoking consent."""
await service.grant_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
scopes=["openid"],
)
revoked = await service.revoke_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
)
assert revoked is True
# Check consent is gone
has_consent = await service.check_consent(
db=db,
user_id=test_user.id,
client_id=public_client.client_id,
requested_scopes=["openid"],
)
assert has_consent is False
class TestOAuthErrors:
"""Tests for OAuth error classes."""
def test_invalid_client_error(self):
"""Test InvalidClientError."""
error = service.InvalidClientError("Test description")
assert error.error == "invalid_client"
assert error.error_description == "Test description"
def test_invalid_grant_error(self):
"""Test InvalidGrantError."""
error = service.InvalidGrantError("Test description")
assert error.error == "invalid_grant"
assert error.error_description == "Test description"
def test_invalid_request_error(self):
"""Test InvalidRequestError."""
error = service.InvalidRequestError("Test description")
assert error.error == "invalid_request"
assert error.error_description == "Test description"
def test_invalid_scope_error(self):
"""Test InvalidScopeError."""
error = service.InvalidScopeError("Test description")
assert error.error == "invalid_scope"
assert error.error_description == "Test description"
def test_access_denied_error(self):
"""Test AccessDeniedError."""
error = service.AccessDeniedError("Test description")
assert error.error == "access_denied"
assert error.error_description == "Test description"

View File

@@ -451,6 +451,7 @@ class TestHandleCallbackComplete:
state="valid_state_login",
provider="google",
code_verifier="test_verifier",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -533,6 +534,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate(
state="valid_state_inactive",
provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -583,6 +585,7 @@ class TestHandleCallbackComplete:
state="valid_state_linking",
provider="github",
user_id=async_test_user.id, # User is logged in
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -648,6 +651,7 @@ class TestHandleCallbackComplete:
state="valid_state_bad_user",
provider="google",
user_id=uuid4(), # Non-existent user
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -707,6 +711,7 @@ class TestHandleCallbackComplete:
state="valid_state_already_linked",
provider="google",
user_id=async_test_user.id,
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -769,6 +774,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate(
state="valid_state_autolink",
provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -832,6 +838,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate(
state="valid_state_new_user",
provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -904,6 +911,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate(
state="valid_state_no_email",
provider="github",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -961,6 +969,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate(
state="valid_state_token_fail",
provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -1004,6 +1013,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate(
state="valid_state_userinfo_fail",
provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -1047,6 +1057,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate(
state="valid_state_no_token",
provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
@@ -1090,6 +1101,7 @@ class TestHandleCallbackComplete:
state_data = OAuthStateCreate(
state="valid_state_no_user_id",
provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)

View File

@@ -153,6 +153,7 @@
"authFailed": "Authentication Failed",
"providerError": "The authentication provider returned an error",
"missingParams": "Missing authentication parameters",
"stateMismatch": "Invalid OAuth state. Please try again.",
"unexpectedError": "An unexpected error occurred during authentication",
"backToLogin": "Back to Login"
}

View File

@@ -153,6 +153,7 @@
"authFailed": "Autenticazione Fallita",
"providerError": "Il provider di autenticazione ha restituito un errore",
"missingParams": "Parametri di autenticazione mancanti",
"stateMismatch": "Stato OAuth non valido. Riprova.",
"unexpectedError": "Si è verificato un errore durante l'autenticazione",
"backToLogin": "Torna al Login"
}

View File

@@ -21,6 +21,24 @@ import { Loader2 } from 'lucide-react';
import { useOAuthCallback } from '@/lib/api/hooks/useOAuth';
import config from '@/config/app.config';
/**
* SECURITY: Constant-time string comparison to prevent timing attacks.
* JavaScript's === operator may short-circuit, potentially leaking information.
* While timing attacks on frontend state are less critical (state is in URL),
* this provides defense-in-depth.
*/
function constantTimeCompare(a: string, b: string): boolean {
if (a.length !== b.length) {
return false;
}
let result = 0;
for (let i = 0; i < a.length; i++) {
result |= a.charCodeAt(i) ^ b.charCodeAt(i);
}
return result === 0;
}
export default function OAuthCallbackPage() {
const params = useParams();
const searchParams = useSearchParams();
@@ -53,6 +71,19 @@ export default function OAuthCallbackPage() {
return;
}
// SECURITY: Validate state parameter against stored value (CSRF protection)
// This prevents cross-site request forgery attacks
// Use constant-time comparison for defense-in-depth
const storedState = sessionStorage.getItem('oauth_state');
if (!storedState || !constantTimeCompare(storedState, state)) {
// Clean up stored state on mismatch
sessionStorage.removeItem('oauth_state');
sessionStorage.removeItem('oauth_mode');
sessionStorage.removeItem('oauth_provider');
setError(t('stateMismatch') || 'Invalid OAuth state. Please try again.');
return;
}
hasProcessed.current = true;
// Process the OAuth callback

View File

@@ -0,0 +1,325 @@
/**
* OAuth Consent Page
* Displays authorization consent form for OAuth provider mode (MCP integration).
*
* Users are redirected here when an external application (MCP client) requests
* access to their account. They can approve or deny the requested permissions.
*/
'use client';
import { useState, useEffect } from 'react';
import { useSearchParams } from 'next/navigation';
import { useRouter } from '@/lib/i18n/routing';
import { useTranslations } from 'next-intl';
import { Button } from '@/components/ui/button';
import {
Card,
CardContent,
CardDescription,
CardFooter,
CardHeader,
CardTitle,
} from '@/components/ui/card';
import { Alert, AlertDescription } from '@/components/ui/alert';
import { Checkbox } from '@/components/ui/checkbox';
import { Label } from '@/components/ui/label';
import { Loader2, Shield, AlertTriangle, ExternalLink, CheckCircle2 } from 'lucide-react';
import { useAuth } from '@/lib/auth/AuthContext';
import config from '@/config/app.config';
// Scope descriptions for display
const SCOPE_INFO: Record<string, { name: string; description: string; icon: string }> = {
openid: {
name: 'OpenID Connect',
description: 'Verify your identity',
icon: 'user',
},
profile: {
name: 'Profile',
description: 'Access your name and basic profile information',
icon: 'user-circle',
},
email: {
name: 'Email',
description: 'Access your email address',
icon: 'mail',
},
'read:users': {
name: 'Read Users',
description: 'View user information',
icon: 'users',
},
'write:users': {
name: 'Write Users',
description: 'Modify user information',
icon: 'user-edit',
},
'read:organizations': {
name: 'Read Organizations',
description: 'View organization information',
icon: 'building',
},
'write:organizations': {
name: 'Write Organizations',
description: 'Modify organization information',
icon: 'building-edit',
},
admin: {
name: 'Admin Access',
description: 'Full administrative access',
icon: 'shield',
},
};
interface ConsentParams {
clientId: string;
clientName: string;
redirectUri: string;
scope: string;
state: string;
codeChallenge: string;
codeChallengeMethod: string;
nonce: string;
}
export default function OAuthConsentPage() {
const searchParams = useSearchParams();
const router = useRouter();
// Note: t is available for future i18n use
const _t = useTranslations('auth.oauth');
void _t; // Suppress unused warning - ready for i18n
const { isAuthenticated, isLoading: authLoading } = useAuth();
const [isSubmitting, setIsSubmitting] = useState(false);
const [error, setError] = useState<string | null>(null);
const [selectedScopes, setSelectedScopes] = useState<Set<string>>(new Set());
const [params, setParams] = useState<ConsentParams | null>(null);
// Parse URL parameters
useEffect(() => {
const clientId = searchParams.get('client_id') || '';
const clientName = searchParams.get('client_name') || 'Application';
const redirectUri = searchParams.get('redirect_uri') || '';
const scope = searchParams.get('scope') || '';
const state = searchParams.get('state') || '';
const codeChallenge = searchParams.get('code_challenge') || '';
const codeChallengeMethod = searchParams.get('code_challenge_method') || '';
const nonce = searchParams.get('nonce') || '';
if (!clientId || !redirectUri) {
setError('Invalid authorization request. Missing required parameters.');
return;
}
setParams({
clientId,
clientName,
redirectUri,
scope,
state,
codeChallenge,
codeChallengeMethod,
nonce,
});
// Initialize selected scopes with all requested scopes
if (scope) {
setSelectedScopes(new Set(scope.split(' ')));
}
}, [searchParams]);
// Redirect to login if not authenticated
useEffect(() => {
if (!authLoading && !isAuthenticated) {
const returnUrl = `/auth/consent?${searchParams.toString()}`;
router.push(`${config.routes.login}?return_to=${encodeURIComponent(returnUrl)}`);
}
}, [authLoading, isAuthenticated, router, searchParams]);
const handleScopeToggle = (scope: string) => {
setSelectedScopes((prev) => {
const next = new Set(prev);
if (next.has(scope)) {
next.delete(scope);
} else {
next.add(scope);
}
return next;
});
};
const handleSubmit = async (approved: boolean) => {
if (!params) return;
setIsSubmitting(true);
setError(null);
try {
// Create form data for consent submission
const formData = new FormData();
formData.append('approved', approved.toString());
formData.append('client_id', params.clientId);
formData.append('redirect_uri', params.redirectUri);
formData.append('scope', Array.from(selectedScopes).join(' '));
formData.append('state', params.state);
if (params.codeChallenge) {
formData.append('code_challenge', params.codeChallenge);
}
if (params.codeChallengeMethod) {
formData.append('code_challenge_method', params.codeChallengeMethod);
}
if (params.nonce) {
formData.append('nonce', params.nonce);
}
// Submit consent to backend
const apiUrl = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000';
const response = await fetch(`${apiUrl}/api/v1/oauth/provider/authorize/consent`, {
method: 'POST',
body: formData,
credentials: 'include',
});
// The endpoint returns a redirect, so follow it
if (response.redirected) {
window.location.href = response.url;
} else if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to process consent');
}
} catch (err) {
setError(err instanceof Error ? err.message : 'An unexpected error occurred');
setIsSubmitting(false);
}
};
// Show loading state while checking auth
if (authLoading) {
return (
<div className="flex min-h-screen items-center justify-center p-4">
<div className="text-center space-y-4">
<Loader2 className="h-8 w-8 animate-spin mx-auto text-primary" />
<p className="text-muted-foreground">Loading...</p>
</div>
</div>
);
}
// Show error state
if (error && !params) {
return (
<div className="flex min-h-screen items-center justify-center p-4">
<div className="w-full max-w-md space-y-4">
<Alert variant="destructive">
<AlertTriangle className="h-4 w-4" />
<AlertDescription>{error}</AlertDescription>
</Alert>
<div className="flex gap-2 justify-center">
<Button variant="outline" onClick={() => router.push(config.routes.login)}>
Back to Login
</Button>
</div>
</div>
</div>
);
}
if (!params) {
return null;
}
const requestedScopes = params.scope ? params.scope.split(' ') : [];
return (
<div className="flex min-h-screen items-center justify-center p-4">
<Card className="w-full max-w-md">
<CardHeader className="text-center">
<div className="flex justify-center mb-4">
<Shield className="h-12 w-12 text-primary" />
</div>
<CardTitle className="text-xl">Authorization Request</CardTitle>
<CardDescription className="mt-2">
<span className="font-semibold text-foreground">{params.clientName}</span> wants to
access your account
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
{error && (
<Alert variant="destructive">
<AlertTriangle className="h-4 w-4" />
<AlertDescription>{error}</AlertDescription>
</Alert>
)}
<div className="space-y-3">
<p className="text-sm font-medium">This application will be able to:</p>
<div className="space-y-2 border rounded-lg p-3">
{requestedScopes.map((scope) => {
const info = SCOPE_INFO[scope] || {
name: scope,
description: `Access to ${scope}`,
};
const isSelected = selectedScopes.has(scope);
return (
<div
key={scope}
className="flex items-start space-x-3 py-2 border-b last:border-0"
>
<Checkbox
id={`scope-${scope}`}
checked={isSelected}
onCheckedChange={() => handleScopeToggle(scope)}
disabled={isSubmitting}
/>
<div className="flex-1 space-y-0.5">
<Label
htmlFor={`scope-${scope}`}
className="text-sm font-medium cursor-pointer"
>
{info.name}
</Label>
<p className="text-xs text-muted-foreground">{info.description}</p>
</div>
{isSelected && <CheckCircle2 className="h-4 w-4 text-green-500 mt-0.5" />}
</div>
);
})}
</div>
</div>
<Alert>
<ExternalLink className="h-4 w-4" />
<AlertDescription className="text-xs">
After authorization, you will be redirected to:
<br />
<code className="text-xs break-all bg-muted px-1 py-0.5 rounded">
{params.redirectUri}
</code>
</AlertDescription>
</Alert>
</CardContent>
<CardFooter className="flex gap-3">
<Button
variant="outline"
className="flex-1"
onClick={() => handleSubmit(false)}
disabled={isSubmitting}
>
{isSubmitting ? <Loader2 className="h-4 w-4 animate-spin" /> : 'Deny'}
</Button>
<Button
className="flex-1"
onClick={() => handleSubmit(true)}
disabled={isSubmitting || selectedScopes.size === 0}
>
{isSubmitting ? <Loader2 className="h-4 w-4 animate-spin" /> : 'Authorize'}
</Button>
</CardFooter>
</Card>
</div>
);
}

View File

@@ -56,6 +56,44 @@ export function useOAuthProviders() {
// OAuth Flow Mutations
// ============================================================================
// Allowed OAuth provider domains for security validation
const ALLOWED_OAUTH_DOMAINS = [
'accounts.google.com',
'github.com',
'www.facebook.com', // For future Facebook support
'login.microsoftonline.com', // For future Microsoft support
];
/**
* Validate OAuth authorization URL
* SECURITY: Prevents open redirect attacks by only allowing known OAuth provider domains
*/
function isValidOAuthUrl(url: string): boolean {
try {
const parsed = new URL(url);
// Only allow HTTPS for OAuth (security requirement)
if (parsed.protocol !== 'https:') {
return false;
}
// Check if domain is in allowlist
return ALLOWED_OAUTH_DOMAINS.includes(parsed.hostname);
} catch {
return false;
}
}
/**
* Extract state parameter from OAuth authorization URL
*/
function extractStateFromUrl(url: string): string | null {
try {
const parsed = new URL(url);
return parsed.searchParams.get('state');
} catch {
return null;
}
}
/**
* Start OAuth login/registration flow
* Redirects user to the OAuth provider
@@ -77,12 +115,27 @@ export function useOAuthStart() {
});
if (response.data) {
// Store mode in sessionStorage for callback handling
sessionStorage.setItem('oauth_mode', mode);
sessionStorage.setItem('oauth_provider', provider);
// Response is { [key: string]: unknown }, so cast authorization_url
const authUrl = (response.data as { authorization_url: string }).authorization_url;
// SECURITY: Validate the authorization URL before redirecting
// This prevents open redirect attacks if the backend is compromised
if (!isValidOAuthUrl(authUrl)) {
throw new Error('Invalid OAuth authorization URL');
}
// SECURITY: Extract and store the state parameter for CSRF validation
// The callback page will verify this matches the state in the response
const state = extractStateFromUrl(authUrl);
if (!state) {
throw new Error('Missing state parameter in authorization URL');
}
// Store mode, provider, and state in sessionStorage for callback handling
sessionStorage.setItem('oauth_mode', mode);
sessionStorage.setItem('oauth_provider', provider);
sessionStorage.setItem('oauth_state', state);
// Redirect to OAuth provider
window.location.href = authUrl;
}
@@ -151,14 +204,16 @@ export function useOAuthCallback() {
queryClient.invalidateQueries({ queryKey: ['user'] });
}
// Clean up session storage
// Clean up session storage (including state for security)
sessionStorage.removeItem('oauth_mode');
sessionStorage.removeItem('oauth_provider');
sessionStorage.removeItem('oauth_state');
},
onError: () => {
// Clean up session storage on error too
sessionStorage.removeItem('oauth_mode');
sessionStorage.removeItem('oauth_provider');
sessionStorage.removeItem('oauth_state');
},
});
}
@@ -199,12 +254,25 @@ export function useOAuthLink() {
});
if (response.data) {
// Store mode in sessionStorage for callback handling
sessionStorage.setItem('oauth_mode', 'link');
sessionStorage.setItem('oauth_provider', provider);
// Response is { [key: string]: unknown }, so cast authorization_url
const authUrl = (response.data as { authorization_url: string }).authorization_url;
// SECURITY: Validate the authorization URL before redirecting
if (!isValidOAuthUrl(authUrl)) {
throw new Error('Invalid OAuth authorization URL');
}
// SECURITY: Extract and store the state parameter for CSRF validation
const state = extractStateFromUrl(authUrl);
if (!state) {
throw new Error('Missing state parameter in authorization URL');
}
// Store mode, provider, and state in sessionStorage for callback handling
sessionStorage.setItem('oauth_mode', 'link');
sessionStorage.setItem('oauth_provider', provider);
sessionStorage.setItem('oauth_state', state);
// Redirect to OAuth provider
window.location.href = authUrl;
}