forked from cardosofelipe/fast-next-template
- Implemented OAuth 2.0 Authorization Server endpoints per RFCs, including token, introspection, revocation, and metadata discovery. - Added user consent submission, listing, and revocation APIs alongside frontend integration for improved UX. - Enforced stricter OAuth security measures (PKCE, state validation, scopes). - Refactored schemas and services for consistency and expanded coverage of OAuth workflows. - Updated documentation and type definitions for new API behaviors.
1070 lines
33 KiB
Python
Executable File
1070 lines
33 KiB
Python
Executable File
"""
|
|
OAuth Provider Service for MCP integration.
|
|
|
|
Implements OAuth 2.0 Authorization Server functionality:
|
|
- Authorization code flow with PKCE
|
|
- Token issuance (JWT access tokens, opaque refresh tokens)
|
|
- Token refresh
|
|
- Token revocation
|
|
- Consent management
|
|
|
|
Security features:
|
|
- PKCE required for public clients (S256)
|
|
- Short-lived authorization codes (10 minutes)
|
|
- JWT access tokens (self-contained, no DB lookup)
|
|
- Secure refresh token storage (hashed)
|
|
- Token rotation on refresh
|
|
- Comprehensive validation
|
|
"""
|
|
|
|
import base64
|
|
import hashlib
|
|
import logging
|
|
import secrets
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from jose import jwt
|
|
from sqlalchemy import and_, delete, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
|
from app.models.oauth_client import OAuthClient
|
|
from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
|
from app.models.user import User
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Constants
|
|
AUTHORIZATION_CODE_EXPIRY_MINUTES = 10
|
|
ACCESS_TOKEN_EXPIRY_MINUTES = 60 # 1 hour for MCP clients
|
|
REFRESH_TOKEN_EXPIRY_DAYS = 30
|
|
|
|
|
|
class OAuthProviderError(Exception):
|
|
"""Base exception for OAuth provider errors."""
|
|
|
|
def __init__(
|
|
self,
|
|
error: str,
|
|
error_description: str | None = None,
|
|
error_uri: str | None = None,
|
|
):
|
|
self.error = error
|
|
self.error_description = error_description
|
|
self.error_uri = error_uri
|
|
super().__init__(error_description or error)
|
|
|
|
|
|
class InvalidClientError(OAuthProviderError):
|
|
"""Client authentication failed."""
|
|
|
|
def __init__(self, description: str = "Invalid client credentials"):
|
|
super().__init__("invalid_client", description)
|
|
|
|
|
|
class InvalidGrantError(OAuthProviderError):
|
|
"""Invalid authorization grant."""
|
|
|
|
def __init__(self, description: str = "Invalid grant"):
|
|
super().__init__("invalid_grant", description)
|
|
|
|
|
|
class InvalidRequestError(OAuthProviderError):
|
|
"""Invalid request parameters."""
|
|
|
|
def __init__(self, description: str = "Invalid request"):
|
|
super().__init__("invalid_request", description)
|
|
|
|
|
|
class InvalidScopeError(OAuthProviderError):
|
|
"""Invalid scope requested."""
|
|
|
|
def __init__(self, description: str = "Invalid scope"):
|
|
super().__init__("invalid_scope", description)
|
|
|
|
|
|
class UnauthorizedClientError(OAuthProviderError):
|
|
"""Client not authorized for this grant type."""
|
|
|
|
def __init__(self, description: str = "Unauthorized client"):
|
|
super().__init__("unauthorized_client", description)
|
|
|
|
|
|
class AccessDeniedError(OAuthProviderError):
|
|
"""User denied authorization."""
|
|
|
|
def __init__(self, description: str = "Access denied"):
|
|
super().__init__("access_denied", description)
|
|
|
|
|
|
# ============================================================================
|
|
# Helper Functions
|
|
# ============================================================================
|
|
|
|
|
|
def generate_code() -> str:
|
|
"""Generate a cryptographically secure authorization code."""
|
|
return secrets.token_urlsafe(64)
|
|
|
|
|
|
def generate_token() -> str:
|
|
"""Generate a cryptographically secure token."""
|
|
return secrets.token_urlsafe(48)
|
|
|
|
|
|
def generate_jti() -> str:
|
|
"""Generate a unique JWT ID."""
|
|
return secrets.token_urlsafe(32)
|
|
|
|
|
|
def hash_token(token: str) -> str:
|
|
"""Hash a token using SHA-256."""
|
|
return hashlib.sha256(token.encode()).hexdigest()
|
|
|
|
|
|
def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
|
|
"""
|
|
Verify PKCE code_verifier against stored code_challenge.
|
|
|
|
SECURITY: Only S256 method is supported. The 'plain' method provides
|
|
no security benefit and is explicitly rejected per RFC 7636 Section 4.3.
|
|
"""
|
|
if method != "S256":
|
|
# SECURITY: Reject any method other than S256
|
|
# 'plain' method provides no security against code interception attacks
|
|
logger.warning(f"PKCE verification rejected for unsupported method: {method}")
|
|
return False
|
|
|
|
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
|
|
digest = hashlib.sha256(code_verifier.encode()).digest()
|
|
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
|
return secrets.compare_digest(computed, code_challenge)
|
|
|
|
|
|
def parse_scope(scope: str) -> list[str]:
|
|
"""Parse space-separated scope string into list."""
|
|
return [s.strip() for s in scope.split() if s.strip()]
|
|
|
|
|
|
def join_scope(scopes: list[str]) -> str:
|
|
"""Join scope list into space-separated string."""
|
|
return " ".join(sorted(set(scopes)))
|
|
|
|
|
|
# ============================================================================
|
|
# Client Validation
|
|
# ============================================================================
|
|
|
|
|
|
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
|
|
"""Get OAuth client by client_id."""
|
|
result = await db.execute(
|
|
select(OAuthClient).where(
|
|
and_(
|
|
OAuthClient.client_id == client_id,
|
|
OAuthClient.is_active == True, # noqa: E712
|
|
)
|
|
)
|
|
)
|
|
return result.scalar_one_or_none()
|
|
|
|
|
|
async def validate_client(
|
|
db: AsyncSession,
|
|
client_id: str,
|
|
client_secret: str | None = None,
|
|
require_secret: bool = False,
|
|
) -> OAuthClient:
|
|
"""
|
|
Validate OAuth client credentials.
|
|
|
|
Args:
|
|
db: Database session
|
|
client_id: Client identifier
|
|
client_secret: Client secret (required for confidential clients)
|
|
require_secret: Whether to require secret validation
|
|
|
|
Returns:
|
|
Validated OAuthClient
|
|
|
|
Raises:
|
|
InvalidClientError: If client validation fails
|
|
"""
|
|
client = await get_client(db, client_id)
|
|
if not client:
|
|
raise InvalidClientError("Unknown client_id")
|
|
|
|
# Confidential clients must provide valid secret
|
|
if client.client_type == "confidential" or require_secret:
|
|
if not client_secret:
|
|
raise InvalidClientError("Client secret required")
|
|
if not client.client_secret_hash:
|
|
raise InvalidClientError("Client not configured with secret")
|
|
|
|
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
|
# Supports both bcrypt and legacy SHA-256 hashes for migration
|
|
from app.core.auth import verify_password
|
|
|
|
stored_hash = str(client.client_secret_hash)
|
|
|
|
if stored_hash.startswith("$2"):
|
|
# New bcrypt format
|
|
if not verify_password(client_secret, stored_hash):
|
|
raise InvalidClientError("Invalid client secret")
|
|
else:
|
|
# Legacy SHA-256 format
|
|
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
|
if not secrets.compare_digest(computed_hash, stored_hash):
|
|
raise InvalidClientError("Invalid client secret")
|
|
|
|
return client
|
|
|
|
|
|
def validate_redirect_uri(client: OAuthClient, redirect_uri: str) -> None:
|
|
"""
|
|
Validate redirect_uri against client's registered URIs.
|
|
|
|
Raises:
|
|
InvalidRequestError: If redirect_uri is not registered
|
|
"""
|
|
if not client.redirect_uris:
|
|
raise InvalidRequestError("Client has no registered redirect URIs")
|
|
|
|
if redirect_uri not in client.redirect_uris:
|
|
raise InvalidRequestError("Invalid redirect_uri")
|
|
|
|
|
|
def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[str]:
|
|
"""
|
|
Validate requested scopes against client's allowed scopes.
|
|
|
|
Returns:
|
|
List of valid scopes (intersection of requested and allowed)
|
|
|
|
Raises:
|
|
InvalidScopeError: If no valid scopes
|
|
"""
|
|
allowed = set(client.allowed_scopes or [])
|
|
requested = set(requested_scopes)
|
|
|
|
# If no scopes requested, use all allowed scopes
|
|
if not requested:
|
|
return list(allowed)
|
|
|
|
valid = requested & allowed
|
|
if not valid:
|
|
raise InvalidScopeError(
|
|
"None of the requested scopes are allowed for this client"
|
|
)
|
|
|
|
# Warn if some scopes were filtered out
|
|
invalid = requested - allowed
|
|
if invalid:
|
|
logger.warning(f"Client {client.client_id} requested invalid scopes: {invalid}")
|
|
|
|
return list(valid)
|
|
|
|
|
|
# ============================================================================
|
|
# Authorization Code Flow
|
|
# ============================================================================
|
|
|
|
|
|
async def create_authorization_code(
|
|
db: AsyncSession,
|
|
client: OAuthClient,
|
|
user: User,
|
|
redirect_uri: str,
|
|
scope: str,
|
|
code_challenge: str | None = None,
|
|
code_challenge_method: str | None = None,
|
|
state: str | None = None,
|
|
nonce: str | None = None,
|
|
) -> str:
|
|
"""
|
|
Create an authorization code for the authorization code flow.
|
|
|
|
Args:
|
|
db: Database session
|
|
client: Validated OAuth client
|
|
user: Authenticated user
|
|
redirect_uri: Validated redirect URI
|
|
scope: Granted scopes (space-separated)
|
|
code_challenge: PKCE code challenge
|
|
code_challenge_method: PKCE method (S256)
|
|
state: CSRF state parameter
|
|
nonce: OpenID Connect nonce
|
|
|
|
Returns:
|
|
Authorization code string
|
|
"""
|
|
# Public clients MUST use PKCE
|
|
if client.client_type == "public":
|
|
if not code_challenge or code_challenge_method != "S256":
|
|
raise InvalidRequestError("PKCE with S256 is required for public clients")
|
|
|
|
code = generate_code()
|
|
expires_at = datetime.now(UTC) + timedelta(
|
|
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
|
|
)
|
|
|
|
auth_code = OAuthAuthorizationCode(
|
|
code=code,
|
|
client_id=client.client_id,
|
|
user_id=user.id,
|
|
redirect_uri=redirect_uri,
|
|
scope=scope,
|
|
code_challenge=code_challenge,
|
|
code_challenge_method=code_challenge_method,
|
|
state=state,
|
|
nonce=nonce,
|
|
expires_at=expires_at,
|
|
used=False,
|
|
)
|
|
|
|
db.add(auth_code)
|
|
await db.commit()
|
|
|
|
logger.info(
|
|
f"Created authorization code for user {user.id} and client {client.client_id}"
|
|
)
|
|
return code
|
|
|
|
|
|
async def exchange_authorization_code(
|
|
db: AsyncSession,
|
|
code: str,
|
|
client_id: str,
|
|
redirect_uri: str,
|
|
code_verifier: str | None = None,
|
|
client_secret: str | None = None,
|
|
device_info: str | None = None,
|
|
ip_address: str | None = None,
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Exchange authorization code for tokens.
|
|
|
|
Args:
|
|
db: Database session
|
|
code: Authorization code
|
|
client_id: Client identifier
|
|
redirect_uri: Must match the original redirect_uri
|
|
code_verifier: PKCE code verifier
|
|
client_secret: Client secret (for confidential clients)
|
|
device_info: Optional device information
|
|
ip_address: Optional IP address
|
|
|
|
Returns:
|
|
Token response dict with access_token, refresh_token, etc.
|
|
|
|
Raises:
|
|
InvalidGrantError: If code is invalid, expired, or already used
|
|
InvalidClientError: If client validation fails
|
|
"""
|
|
# Atomically mark code as used and fetch it (prevents race condition)
|
|
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
|
from sqlalchemy import update
|
|
|
|
# First, atomically mark the code as used and get affected count
|
|
update_stmt = (
|
|
update(OAuthAuthorizationCode)
|
|
.where(
|
|
and_(
|
|
OAuthAuthorizationCode.code == code,
|
|
OAuthAuthorizationCode.used == False, # noqa: E712
|
|
)
|
|
)
|
|
.values(used=True)
|
|
.returning(OAuthAuthorizationCode.id)
|
|
)
|
|
result = await db.execute(update_stmt)
|
|
updated_id = result.scalar_one_or_none()
|
|
|
|
if not updated_id:
|
|
# Either code doesn't exist or was already used
|
|
# Check if it exists to provide appropriate error
|
|
check_result = await db.execute(
|
|
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
|
)
|
|
existing_code = check_result.scalar_one_or_none()
|
|
|
|
if existing_code and existing_code.used:
|
|
# Code reuse is a security incident - revoke all tokens for this grant
|
|
logger.warning(
|
|
f"Authorization code reuse detected for client {existing_code.client_id}"
|
|
)
|
|
await revoke_tokens_for_user_client(
|
|
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
|
|
)
|
|
raise InvalidGrantError("Authorization code has already been used")
|
|
else:
|
|
raise InvalidGrantError("Invalid authorization code")
|
|
|
|
# Now fetch the full auth code record
|
|
auth_code_result = await db.execute(
|
|
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
|
|
)
|
|
auth_code = auth_code_result.scalar_one()
|
|
await db.commit()
|
|
|
|
if auth_code.is_expired:
|
|
raise InvalidGrantError("Authorization code has expired")
|
|
|
|
if auth_code.client_id != client_id:
|
|
raise InvalidGrantError("Authorization code was not issued to this client")
|
|
|
|
if auth_code.redirect_uri != redirect_uri:
|
|
raise InvalidGrantError("redirect_uri mismatch")
|
|
|
|
# Validate client - ALWAYS require secret for confidential clients
|
|
client = await get_client(db, client_id)
|
|
if not client:
|
|
raise InvalidClientError("Unknown client_id")
|
|
|
|
# Confidential clients MUST authenticate (RFC 6749 Section 3.2.1)
|
|
if client.client_type == "confidential":
|
|
if not client_secret:
|
|
raise InvalidClientError("Client secret required for confidential clients")
|
|
client = await validate_client(
|
|
db, client_id, client_secret, require_secret=True
|
|
)
|
|
elif client_secret:
|
|
# Public client provided secret - validate it if given
|
|
client = await validate_client(
|
|
db, client_id, client_secret, require_secret=True
|
|
)
|
|
|
|
# Verify PKCE
|
|
if auth_code.code_challenge:
|
|
if not code_verifier:
|
|
raise InvalidGrantError("code_verifier required")
|
|
if not verify_pkce(
|
|
code_verifier,
|
|
str(auth_code.code_challenge),
|
|
str(auth_code.code_challenge_method or "S256"),
|
|
):
|
|
raise InvalidGrantError("Invalid code_verifier")
|
|
elif client.client_type == "public":
|
|
# Public clients without PKCE - this shouldn't happen if we validated on authorize
|
|
raise InvalidGrantError("PKCE required for public clients")
|
|
|
|
# Get user
|
|
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
|
|
user = user_result.scalar_one_or_none()
|
|
if not user or not user.is_active:
|
|
raise InvalidGrantError("User not found or inactive")
|
|
|
|
# Generate tokens
|
|
return await create_tokens(
|
|
db=db,
|
|
client=client,
|
|
user=user,
|
|
scope=str(auth_code.scope),
|
|
nonce=str(auth_code.nonce) if auth_code.nonce else None,
|
|
device_info=device_info,
|
|
ip_address=ip_address,
|
|
)
|
|
|
|
|
|
# ============================================================================
|
|
# Token Generation
|
|
# ============================================================================
|
|
|
|
|
|
async def create_tokens(
|
|
db: AsyncSession,
|
|
client: OAuthClient,
|
|
user: User,
|
|
scope: str,
|
|
nonce: str | None = None,
|
|
device_info: str | None = None,
|
|
ip_address: str | None = None,
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Create access and refresh tokens.
|
|
|
|
Args:
|
|
db: Database session
|
|
client: OAuth client
|
|
user: User
|
|
scope: Granted scopes
|
|
nonce: OpenID Connect nonce (included in ID token)
|
|
device_info: Optional device information
|
|
ip_address: Optional IP address
|
|
|
|
Returns:
|
|
Token response dict
|
|
"""
|
|
now = datetime.now(UTC)
|
|
jti = generate_jti()
|
|
|
|
# Access token expiry
|
|
access_token_lifetime = int(client.access_token_lifetime or "3600")
|
|
access_expires = now + timedelta(seconds=access_token_lifetime)
|
|
|
|
# Refresh token expiry
|
|
refresh_token_lifetime = int(
|
|
client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400)
|
|
)
|
|
refresh_expires = now + timedelta(seconds=refresh_token_lifetime)
|
|
|
|
# Create JWT access token
|
|
# SECURITY: Include all standard JWT claims per RFC 7519
|
|
access_token_payload = {
|
|
"iss": settings.OAUTH_ISSUER,
|
|
"sub": str(user.id),
|
|
"aud": client.client_id,
|
|
"exp": int(access_expires.timestamp()),
|
|
"iat": int(now.timestamp()),
|
|
"nbf": int(now.timestamp()), # Not Before - token is valid immediately
|
|
"jti": jti,
|
|
"scope": scope,
|
|
"client_id": client.client_id,
|
|
# User info (basic claims)
|
|
"email": user.email,
|
|
"name": f"{user.first_name or ''} {user.last_name or ''}".strip() or user.email,
|
|
}
|
|
|
|
# Add nonce for OpenID Connect
|
|
if nonce:
|
|
access_token_payload["nonce"] = nonce
|
|
|
|
access_token = jwt.encode(
|
|
access_token_payload,
|
|
settings.SECRET_KEY,
|
|
algorithm=settings.ALGORITHM,
|
|
)
|
|
|
|
# Create opaque refresh token
|
|
refresh_token = generate_token()
|
|
refresh_token_hash = hash_token(refresh_token)
|
|
|
|
# Store refresh token in database
|
|
refresh_token_record = OAuthProviderRefreshToken(
|
|
token_hash=refresh_token_hash,
|
|
jti=jti,
|
|
client_id=client.client_id,
|
|
user_id=user.id,
|
|
scope=scope,
|
|
expires_at=refresh_expires,
|
|
device_info=device_info,
|
|
ip_address=ip_address,
|
|
)
|
|
db.add(refresh_token_record)
|
|
await db.commit()
|
|
|
|
logger.info(f"Issued tokens for user {user.id} to client {client.client_id}")
|
|
|
|
return {
|
|
"access_token": access_token,
|
|
"token_type": "Bearer",
|
|
"expires_in": access_token_lifetime,
|
|
"refresh_token": refresh_token,
|
|
"scope": scope,
|
|
}
|
|
|
|
|
|
async def refresh_tokens(
|
|
db: AsyncSession,
|
|
refresh_token: str,
|
|
client_id: str,
|
|
client_secret: str | None = None,
|
|
scope: str | None = None,
|
|
device_info: str | None = None,
|
|
ip_address: str | None = None,
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Refresh access token using refresh token.
|
|
|
|
Implements token rotation - old refresh token is invalidated,
|
|
new refresh token is issued.
|
|
|
|
Args:
|
|
db: Database session
|
|
refresh_token: Refresh token
|
|
client_id: Client identifier
|
|
client_secret: Client secret (for confidential clients)
|
|
scope: Optional reduced scope
|
|
device_info: Optional device information
|
|
ip_address: Optional IP address
|
|
|
|
Returns:
|
|
New token response dict
|
|
|
|
Raises:
|
|
InvalidGrantError: If refresh token is invalid
|
|
"""
|
|
# Find refresh token
|
|
token_hash = hash_token(refresh_token)
|
|
result = await db.execute(
|
|
select(OAuthProviderRefreshToken).where(
|
|
OAuthProviderRefreshToken.token_hash == token_hash
|
|
)
|
|
)
|
|
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
|
|
|
|
if not token_record:
|
|
raise InvalidGrantError("Invalid refresh token")
|
|
|
|
if token_record.revoked:
|
|
# Token reuse after revocation - security incident
|
|
logger.warning(
|
|
f"Revoked refresh token reuse detected for client {token_record.client_id}"
|
|
)
|
|
raise InvalidGrantError("Refresh token has been revoked")
|
|
|
|
if token_record.is_expired:
|
|
raise InvalidGrantError("Refresh token has expired")
|
|
|
|
if token_record.client_id != client_id:
|
|
raise InvalidGrantError("Refresh token was not issued to this client")
|
|
|
|
# Validate client
|
|
client = await validate_client(
|
|
db,
|
|
client_id,
|
|
client_secret,
|
|
require_secret=(client_secret is not None),
|
|
)
|
|
|
|
# Get user
|
|
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
|
|
user = user_result.scalar_one_or_none()
|
|
if not user or not user.is_active:
|
|
raise InvalidGrantError("User not found or inactive")
|
|
|
|
# Validate scope (can only reduce, not expand)
|
|
token_scope = str(token_record.scope) if token_record.scope else ""
|
|
original_scopes = set(parse_scope(token_scope))
|
|
if scope:
|
|
requested_scopes = set(parse_scope(scope))
|
|
if not requested_scopes.issubset(original_scopes):
|
|
raise InvalidScopeError("Cannot expand scope on refresh")
|
|
final_scope = join_scope(list(requested_scopes))
|
|
else:
|
|
final_scope = token_scope
|
|
|
|
# Revoke old refresh token (token rotation)
|
|
token_record.revoked = True # type: ignore[assignment]
|
|
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
|
await db.commit()
|
|
|
|
# Issue new tokens
|
|
device = str(token_record.device_info) if token_record.device_info else None
|
|
ip_addr = str(token_record.ip_address) if token_record.ip_address else None
|
|
return await create_tokens(
|
|
db=db,
|
|
client=client,
|
|
user=user,
|
|
scope=final_scope,
|
|
device_info=device_info or device,
|
|
ip_address=ip_address or ip_addr,
|
|
)
|
|
|
|
|
|
# ============================================================================
|
|
# Token Revocation
|
|
# ============================================================================
|
|
|
|
|
|
async def revoke_token(
|
|
db: AsyncSession,
|
|
token: str,
|
|
token_type_hint: str | None = None,
|
|
client_id: str | None = None,
|
|
client_secret: str | None = None,
|
|
) -> bool:
|
|
"""
|
|
Revoke a token (access or refresh).
|
|
|
|
For refresh tokens: marks as revoked in database
|
|
For access tokens: we can't truly revoke JWTs, but we can revoke
|
|
the associated refresh token to prevent further refreshes
|
|
|
|
Args:
|
|
db: Database session
|
|
token: Token to revoke
|
|
token_type_hint: "access_token" or "refresh_token"
|
|
client_id: Client identifier (for validation)
|
|
client_secret: Client secret (for confidential clients)
|
|
|
|
Returns:
|
|
True if token was revoked, False if not found
|
|
"""
|
|
# Try as refresh token first (more likely)
|
|
if token_type_hint != "access_token":
|
|
token_hash = hash_token(token)
|
|
result = await db.execute(
|
|
select(OAuthProviderRefreshToken).where(
|
|
OAuthProviderRefreshToken.token_hash == token_hash
|
|
)
|
|
)
|
|
refresh_record = result.scalar_one_or_none()
|
|
|
|
if refresh_record:
|
|
# Validate client if provided
|
|
if client_id and refresh_record.client_id != client_id:
|
|
raise InvalidClientError("Token was not issued to this client")
|
|
|
|
refresh_record.revoked = True # type: ignore[assignment]
|
|
await db.commit()
|
|
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
|
|
return True
|
|
|
|
# Try as access token (JWT)
|
|
if token_type_hint != "refresh_token":
|
|
try:
|
|
from jose.exceptions import JWTError
|
|
|
|
payload = jwt.decode(
|
|
token,
|
|
settings.SECRET_KEY,
|
|
algorithms=[settings.ALGORITHM],
|
|
options={
|
|
"verify_exp": False,
|
|
"verify_aud": False,
|
|
}, # Allow expired tokens
|
|
)
|
|
jti = payload.get("jti")
|
|
if jti:
|
|
# Find and revoke the associated refresh token
|
|
result = await db.execute(
|
|
select(OAuthProviderRefreshToken).where(
|
|
OAuthProviderRefreshToken.jti == jti
|
|
)
|
|
)
|
|
refresh_record = result.scalar_one_or_none()
|
|
if refresh_record:
|
|
if client_id and refresh_record.client_id != client_id:
|
|
raise InvalidClientError("Token was not issued to this client")
|
|
refresh_record.revoked = True # type: ignore[assignment]
|
|
await db.commit()
|
|
logger.info(
|
|
f"Revoked refresh token via access token JTI {jti[:8]}..."
|
|
)
|
|
return True
|
|
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT not an error
|
|
pass
|
|
|
|
return False
|
|
|
|
|
|
async def revoke_tokens_for_user_client(
|
|
db: AsyncSession,
|
|
user_id: UUID,
|
|
client_id: str,
|
|
) -> int:
|
|
"""
|
|
Revoke all tokens for a specific user-client pair.
|
|
|
|
Used when security incidents are detected (e.g., code reuse).
|
|
|
|
Args:
|
|
db: Database session
|
|
user_id: User identifier
|
|
client_id: Client identifier
|
|
|
|
Returns:
|
|
Number of tokens revoked
|
|
"""
|
|
result = await db.execute(
|
|
select(OAuthProviderRefreshToken).where(
|
|
and_(
|
|
OAuthProviderRefreshToken.user_id == user_id,
|
|
OAuthProviderRefreshToken.client_id == client_id,
|
|
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
|
)
|
|
)
|
|
)
|
|
tokens = result.scalars().all()
|
|
|
|
count = 0
|
|
for token in tokens:
|
|
token.revoked = True # type: ignore[assignment]
|
|
count += 1
|
|
|
|
if count > 0:
|
|
await db.commit()
|
|
logger.warning(
|
|
f"Revoked {count} tokens for user {user_id} and client {client_id}"
|
|
)
|
|
|
|
return count
|
|
|
|
|
|
async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
|
|
"""
|
|
Revoke all OAuth provider tokens for a user.
|
|
|
|
Used when user changes password or explicitly logs out everywhere.
|
|
|
|
Args:
|
|
db: Database session
|
|
user_id: User identifier
|
|
|
|
Returns:
|
|
Number of tokens revoked
|
|
"""
|
|
result = await db.execute(
|
|
select(OAuthProviderRefreshToken).where(
|
|
and_(
|
|
OAuthProviderRefreshToken.user_id == user_id,
|
|
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
|
)
|
|
)
|
|
)
|
|
tokens = result.scalars().all()
|
|
|
|
count = 0
|
|
for token in tokens:
|
|
token.revoked = True # type: ignore[assignment]
|
|
count += 1
|
|
|
|
if count > 0:
|
|
await db.commit()
|
|
logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}")
|
|
|
|
return count
|
|
|
|
|
|
# ============================================================================
|
|
# Token Introspection (RFC 7662)
|
|
# ============================================================================
|
|
|
|
|
|
async def introspect_token(
|
|
db: AsyncSession,
|
|
token: str,
|
|
token_type_hint: str | None = None,
|
|
client_id: str | None = None,
|
|
client_secret: str | None = None,
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Introspect a token to determine its validity and metadata.
|
|
|
|
Implements RFC 7662 Token Introspection.
|
|
|
|
Args:
|
|
db: Database session
|
|
token: Token to introspect
|
|
token_type_hint: "access_token" or "refresh_token"
|
|
client_id: Client requesting introspection
|
|
client_secret: Client secret
|
|
|
|
Returns:
|
|
Introspection response dict
|
|
"""
|
|
# Validate client if credentials provided
|
|
if client_id:
|
|
await validate_client(db, client_id, client_secret)
|
|
|
|
# Try as access token (JWT) first
|
|
if token_type_hint != "refresh_token":
|
|
try:
|
|
from jose.exceptions import ExpiredSignatureError, JWTError
|
|
|
|
payload = jwt.decode(
|
|
token,
|
|
settings.SECRET_KEY,
|
|
algorithms=[settings.ALGORITHM],
|
|
options={
|
|
"verify_aud": False
|
|
}, # Don't require audience match for introspection
|
|
)
|
|
|
|
# Check if associated refresh token is revoked
|
|
jti = payload.get("jti")
|
|
if jti:
|
|
result = await db.execute(
|
|
select(OAuthProviderRefreshToken).where(
|
|
OAuthProviderRefreshToken.jti == jti
|
|
)
|
|
)
|
|
refresh_record = result.scalar_one_or_none()
|
|
if refresh_record and refresh_record.revoked:
|
|
return {"active": False}
|
|
|
|
return {
|
|
"active": True,
|
|
"scope": payload.get("scope", ""),
|
|
"client_id": payload.get("client_id"),
|
|
"username": payload.get("email"),
|
|
"token_type": "Bearer",
|
|
"exp": payload.get("exp"),
|
|
"iat": payload.get("iat"),
|
|
"sub": payload.get("sub"),
|
|
"aud": payload.get("aud"),
|
|
"iss": payload.get("iss"),
|
|
}
|
|
except ExpiredSignatureError:
|
|
return {"active": False}
|
|
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
|
|
pass
|
|
|
|
# Try as refresh token
|
|
if token_type_hint != "access_token":
|
|
token_hash = hash_token(token)
|
|
result = await db.execute(
|
|
select(OAuthProviderRefreshToken).where(
|
|
OAuthProviderRefreshToken.token_hash == token_hash
|
|
)
|
|
)
|
|
refresh_record = result.scalar_one_or_none()
|
|
|
|
if refresh_record and refresh_record.is_valid:
|
|
return {
|
|
"active": True,
|
|
"scope": refresh_record.scope,
|
|
"client_id": refresh_record.client_id,
|
|
"token_type": "refresh_token",
|
|
"exp": int(refresh_record.expires_at.timestamp()),
|
|
"iat": int(refresh_record.created_at.timestamp()),
|
|
"sub": str(refresh_record.user_id),
|
|
}
|
|
|
|
return {"active": False}
|
|
|
|
|
|
# ============================================================================
|
|
# Consent Management
|
|
# ============================================================================
|
|
|
|
|
|
async def get_consent(
|
|
db: AsyncSession,
|
|
user_id: UUID,
|
|
client_id: str,
|
|
) -> OAuthConsent | None:
|
|
"""Get existing consent record for user-client pair."""
|
|
result = await db.execute(
|
|
select(OAuthConsent).where(
|
|
and_(
|
|
OAuthConsent.user_id == user_id,
|
|
OAuthConsent.client_id == client_id,
|
|
)
|
|
)
|
|
)
|
|
return result.scalar_one_or_none()
|
|
|
|
|
|
async def check_consent(
|
|
db: AsyncSession,
|
|
user_id: UUID,
|
|
client_id: str,
|
|
requested_scopes: list[str],
|
|
) -> bool:
|
|
"""
|
|
Check if user has already consented to the requested scopes.
|
|
|
|
Returns True if all requested scopes are already granted.
|
|
"""
|
|
consent = await get_consent(db, user_id, client_id)
|
|
if not consent:
|
|
return False
|
|
return consent.has_scopes(requested_scopes)
|
|
|
|
|
|
async def grant_consent(
|
|
db: AsyncSession,
|
|
user_id: UUID,
|
|
client_id: str,
|
|
scopes: list[str],
|
|
) -> OAuthConsent:
|
|
"""
|
|
Grant or update consent for a user-client pair.
|
|
|
|
If consent already exists, updates the granted scopes.
|
|
"""
|
|
consent = await get_consent(db, user_id, client_id)
|
|
|
|
if consent:
|
|
# Merge scopes
|
|
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
|
|
existing = set(parse_scope(granted))
|
|
new_scopes = existing | set(scopes)
|
|
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
|
|
else:
|
|
consent = OAuthConsent(
|
|
user_id=user_id,
|
|
client_id=client_id,
|
|
granted_scopes=join_scope(scopes),
|
|
)
|
|
db.add(consent)
|
|
|
|
await db.commit()
|
|
await db.refresh(consent)
|
|
return consent
|
|
|
|
|
|
async def revoke_consent(
|
|
db: AsyncSession,
|
|
user_id: UUID,
|
|
client_id: str,
|
|
) -> bool:
|
|
"""
|
|
Revoke consent and all tokens for a user-client pair.
|
|
|
|
Returns True if consent was found and revoked.
|
|
"""
|
|
# Delete consent record
|
|
result = await db.execute(
|
|
delete(OAuthConsent).where(
|
|
and_(
|
|
OAuthConsent.user_id == user_id,
|
|
OAuthConsent.client_id == client_id,
|
|
)
|
|
)
|
|
)
|
|
|
|
# Revoke all tokens
|
|
await revoke_tokens_for_user_client(db, user_id, client_id)
|
|
|
|
await db.commit()
|
|
return result.rowcount > 0 # type: ignore[attr-defined]
|
|
|
|
|
|
# ============================================================================
|
|
# Cleanup
|
|
# ============================================================================
|
|
|
|
|
|
async def cleanup_expired_codes(db: AsyncSession) -> int:
|
|
"""
|
|
Delete expired authorization codes.
|
|
|
|
Should be called periodically (e.g., every hour).
|
|
|
|
Returns:
|
|
Number of codes deleted
|
|
"""
|
|
result = await db.execute(
|
|
delete(OAuthAuthorizationCode).where(
|
|
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
|
)
|
|
)
|
|
await db.commit()
|
|
return result.rowcount # type: ignore[attr-defined]
|
|
|
|
|
|
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
|
"""
|
|
Delete expired and revoked refresh tokens.
|
|
|
|
Should be called periodically (e.g., daily).
|
|
|
|
Returns:
|
|
Number of tokens deleted
|
|
"""
|
|
# Delete tokens that are both expired AND revoked (or just very old)
|
|
cutoff = datetime.now(UTC) - timedelta(days=7)
|
|
result = await db.execute(
|
|
delete(OAuthProviderRefreshToken).where(
|
|
OAuthProviderRefreshToken.expires_at < cutoff
|
|
)
|
|
)
|
|
await db.commit()
|
|
return result.rowcount # type: ignore[attr-defined]
|