forked from cardosofelipe/pragma-stack
Add full OAuth provider functionality and enhance flows
- Implemented OAuth 2.0 Authorization Server endpoints per RFCs, including token, introspection, revocation, and metadata discovery. - Added user consent submission, listing, and revocation APIs alongside frontend integration for improved UX. - Enforced stricter OAuth security measures (PKCE, state validation, scopes). - Refactored schemas and services for consistency and expanded coverage of OAuth workflows. - Updated documentation and type definitions for new API behaviors.
This commit is contained in:
0
backend/app/alembic/versions/f8c3d2e1a4b5_add_oauth_provider_models.py
Normal file → Executable file
0
backend/app/alembic/versions/f8c3d2e1a4b5_add_oauth_provider_models.py
Normal file → Executable file
@@ -27,7 +27,11 @@ from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_active_user, get_current_superuser
|
||||
from app.api.dependencies.auth import (
|
||||
get_current_active_user,
|
||||
get_current_superuser,
|
||||
get_optional_current_user,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.crud import oauth_client as oauth_client_crud
|
||||
@@ -42,6 +46,8 @@ from app.schemas.oauth import (
|
||||
from app.services import oauth_provider_service as provider_service
|
||||
|
||||
router = APIRouter()
|
||||
# Separate router for RFC 8414 well-known endpoint (registered at root level)
|
||||
wellknown_router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
@@ -60,7 +66,7 @@ def require_provider_enabled():
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
@wellknown_router.get(
|
||||
"/.well-known/oauth-authorization-server",
|
||||
response_model=OAuthServerMetadata,
|
||||
summary="OAuth Server Metadata",
|
||||
@@ -69,6 +75,8 @@ def require_provider_enabled():
|
||||
|
||||
Returns server metadata including supported endpoints, scopes,
|
||||
and capabilities. MCP clients use this to discover the server.
|
||||
|
||||
Note: This endpoint is at the root level per RFC 8414.
|
||||
""",
|
||||
operation_id="get_oauth_server_metadata",
|
||||
tags=["OAuth Provider"],
|
||||
@@ -153,7 +161,7 @@ async def authorize(
|
||||
nonce: str | None = Query(default=None, description="OpenID Connect nonce"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User | None = Depends(get_current_active_user),
|
||||
current_user: User | None = Depends(get_optional_current_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Authorization endpoint - initiates OAuth flow.
|
||||
|
||||
@@ -14,6 +14,7 @@ from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from app.api.main import api_router
|
||||
from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router
|
||||
from app.core.config import settings
|
||||
from app.core.database import check_database_health
|
||||
from app.core.exceptions import (
|
||||
@@ -324,3 +325,7 @@ async def health_check() -> JSONResponse:
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
# OAuth 2.0 well-known endpoint at root level per RFC 8414
|
||||
# This allows MCP clients to discover the OAuth server metadata at /.well-known/oauth-authorization-server
|
||||
app.include_router(oauth_wellknown_router)
|
||||
|
||||
0
backend/app/models/oauth_account.py
Normal file → Executable file
0
backend/app/models/oauth_account.py
Normal file → Executable file
0
backend/app/models/oauth_authorization_code.py
Normal file → Executable file
0
backend/app/models/oauth_authorization_code.py
Normal file → Executable file
0
backend/app/models/oauth_client.py
Normal file → Executable file
0
backend/app/models/oauth_client.py
Normal file → Executable file
0
backend/app/models/oauth_provider_token.py
Normal file → Executable file
0
backend/app/models/oauth_provider_token.py
Normal file → Executable file
0
backend/app/models/oauth_state.py
Normal file → Executable file
0
backend/app/models/oauth_state.py
Normal file → Executable file
0
backend/app/services/oauth_provider_service.py
Normal file → Executable file
0
backend/app/services/oauth_provider_service.py
Normal file → Executable file
@@ -291,9 +291,8 @@ class TestOAuthProviderEndpoints:
|
||||
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||
mock_settings.OAUTH_PROVIDER_ENABLED = False
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/oauth/.well-known/oauth-authorization-server"
|
||||
)
|
||||
# RFC 8414: well-known endpoint is at root level
|
||||
response = await client.get("/.well-known/oauth-authorization-server")
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -303,9 +302,8 @@ class TestOAuthProviderEndpoints:
|
||||
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||
mock_settings.OAUTH_ISSUER = "https://api.example.com"
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/oauth/.well-known/oauth-authorization-server"
|
||||
)
|
||||
# RFC 8414: well-known endpoint is at root level
|
||||
response = await client.get("/.well-known/oauth-authorization-server")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["issuer"] == "https://api.example.com"
|
||||
@@ -344,8 +342,10 @@ class TestOAuthProviderEndpoints:
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_authorize_requires_auth(self, client, async_test_db):
|
||||
"""Test provider authorize requires authentication."""
|
||||
async def test_provider_authorize_public_client_requires_pkce(
|
||||
self, client, async_test_db
|
||||
):
|
||||
"""Test provider authorize requires PKCE for public clients."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a test client
|
||||
@@ -373,9 +373,54 @@ class TestOAuthProviderEndpoints:
|
||||
"client_id": test_client_id,
|
||||
"redirect_uri": "http://localhost:3000/callback",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
# Authorize endpoint requires authentication
|
||||
assert response.status_code == 401
|
||||
# Public client without PKCE gets redirect with error
|
||||
assert response.status_code == 302
|
||||
assert "error=invalid_request" in response.headers.get("location", "")
|
||||
assert "PKCE" in response.headers.get("location", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_authorize_redirects_to_login(self, client, async_test_db):
|
||||
"""Test provider authorize redirects unauthenticated users to login."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a test client
|
||||
from app.crud.oauth import oauth_client
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
client_data = OAuthClientCreate(
|
||||
client_name="Test App",
|
||||
redirect_uris=["http://localhost:3000/callback"],
|
||||
allowed_scopes=["read:users"],
|
||||
)
|
||||
test_client, _ = await oauth_client.create_client(
|
||||
session, obj_in=client_data
|
||||
)
|
||||
test_client_id = test_client.client_id
|
||||
|
||||
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||
mock_settings.FRONTEND_URL = "http://localhost:3000"
|
||||
|
||||
# Include PKCE parameters for public client
|
||||
response = await client.get(
|
||||
"/api/v1/oauth/provider/authorize",
|
||||
params={
|
||||
"response_type": "code",
|
||||
"client_id": test_client_id,
|
||||
"redirect_uri": "http://localhost:3000/callback",
|
||||
"code_challenge": "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
|
||||
"code_challenge_method": "S256",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
# Unauthenticated users get redirected to login
|
||||
assert response.status_code == 302
|
||||
location = response.headers.get("location", "")
|
||||
assert "/login" in location
|
||||
assert "return_to" in location
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_token_requires_client_id(self, client):
|
||||
|
||||
Reference in New Issue
Block a user