Remove token revocation logic and unused dependencies
Eliminated the `RevokedToken` model and associated logic for managing token revocation. Removed unused files, related tests, and outdated dependencies in authentication modules. Simplified token decoding, user validation, and dependency injection by streamlining the flow and enhancing maintainability.
This commit is contained in:
@@ -1,40 +0,0 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.auth.security import decode_token
|
||||
from app.models.user import User
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
payload = await decode_token(token) # Use updated decode_token.
|
||||
user_id: str = payload.sub
|
||||
token_type: str = payload.type
|
||||
|
||||
if user_id is None or token_type != "access":
|
||||
raise HTTPException(status_code=401, detail="Invalid token type.")
|
||||
|
||||
user = await db.get(User, user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User not found.")
|
||||
|
||||
return user
|
||||
except JWTError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
@@ -1,176 +0,0 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import Depends
|
||||
from jose import jwt, JWTError, ExpiredSignatureError, JOSEError
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.schemas.token import TokenPayload, TokenResponse
|
||||
from auth.utils import is_token_revoked
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a plain password against its hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate password hash."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_tokens(user_id: str) -> TokenResponse:
|
||||
"""
|
||||
Create both access and refresh tokens for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
TokenResponse containing both tokens and metadata
|
||||
"""
|
||||
# Add `jti` during token creation
|
||||
access_token = create_access_token({"sub": user_id, "jti": str(uuid4())})
|
||||
refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())})
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
user_id=user_id,
|
||||
scope="read write"
|
||||
)
|
||||
|
||||
|
||||
def create_token(
|
||||
data: dict,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
token_type: str = "access"
|
||||
) -> str:
|
||||
"""Create a JWT token with the specified type and expiration."""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access"
|
||||
else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
)
|
||||
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"type": token_type,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
})
|
||||
if "jti" not in to_encode:
|
||||
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
|
||||
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new access token."""
|
||||
# Ensure `data` includes `jti` for consistency
|
||||
if "jti" not in data:
|
||||
data["jti"] = str(uuid4())
|
||||
return create_token(data, expires_delta, "access")
|
||||
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a new refresh token."""
|
||||
# Ensure `data` includes `jti` for consistency
|
||||
if "jti" not in data:
|
||||
data["jti"] = str(uuid4())
|
||||
return create_token(data, expires_delta, "refresh")
|
||||
|
||||
|
||||
async def decode_token(
|
||||
token: str,
|
||||
required_type: str = "access",
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> TokenPayload:
|
||||
"""
|
||||
Decode and validate a JWT token, including revocation checks.
|
||||
|
||||
Args:
|
||||
token (str): The JWT token to decode.
|
||||
required_type (str): The expected token type (default: "access").
|
||||
db (AsyncSession): Database session for token revocation checks.
|
||||
|
||||
Returns:
|
||||
TokenPayload: The decoded token data.
|
||||
|
||||
Raises:
|
||||
JWTError: If the token is expired, revoked, malformed, or fails validation.
|
||||
"""
|
||||
try:
|
||||
# Step 1: Decode the JWT token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
SECRET_KEY,
|
||||
algorithms=[ALGORITHM],
|
||||
options={
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "iat", "sub", "type", "jti"]
|
||||
}
|
||||
)
|
||||
|
||||
except ExpiredSignatureError:
|
||||
raise JWTError("Token has expired. Please refresh your token or login again.")
|
||||
except JWTError as e:
|
||||
if "Signature verification failed" in str(e):
|
||||
raise JWTError("Invalid token signature. The token may have been tampered with or corrupted.")
|
||||
raise JWTError(f"Failed to decode the token: {e}")
|
||||
except JOSEError as e:
|
||||
if "segments" in str(e).lower():
|
||||
raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).")
|
||||
raise JWTError("Failed to decode the token. Ensure the token is valid and correctly formatted.") from e
|
||||
except Exception as e:
|
||||
# Catch-all for unexpected exceptions during decoding
|
||||
raise JWTError(f"An unexpected error occurred while decoding the token: {e}")
|
||||
|
||||
# Step 2: Validate Required Claims
|
||||
required_claims = ["exp", "sub", "type", "jti"]
|
||||
missing_claims = [claim for claim in required_claims if claim not in payload]
|
||||
if missing_claims:
|
||||
raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.")
|
||||
|
||||
# Step 3: Validate Expiry
|
||||
expiration = datetime.fromtimestamp(payload["exp"])
|
||||
if datetime.now(timezone.utc) > expiration:
|
||||
raise JWTError("Token has expired. Please refresh your token or login again.")
|
||||
|
||||
# Step 4: Validate Token Type
|
||||
token_type = payload.get("type")
|
||||
if token_type != required_type:
|
||||
raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.")
|
||||
|
||||
# Step 5: Check Revocation
|
||||
jti = payload.get("jti")
|
||||
if await is_token_revoked(jti, db):
|
||||
raise JWTError("Token has been revoked. Please login again to generate a new token.")
|
||||
|
||||
# Step 6: Return Validated Token Payload
|
||||
return TokenPayload(
|
||||
sub=payload["sub"],
|
||||
type=payload["type"],
|
||||
exp=expiration,
|
||||
iat=datetime.fromtimestamp(payload.get("iat", 0)),
|
||||
jti=jti
|
||||
)
|
||||
@@ -1,45 +0,0 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.token import RevokedToken
|
||||
|
||||
|
||||
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
|
||||
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
|
||||
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
|
||||
db.add(revoked_token)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
|
||||
"""Check whether the token's JTI is in the revoked_tokens table."""
|
||||
from sqlalchemy import select
|
||||
result = await db.execute(select(RevokedToken).where(RevokedToken.jti == jti))
|
||||
revoked = result.scalar_one_or_none()
|
||||
return revoked is not None
|
||||
|
||||
|
||||
async def cleanup_expired_tokens(db: AsyncSession):
|
||||
"""Delete revoked tokens that are past their expiration time."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# For access tokens (shorter expiry)
|
||||
expire_before = now - timedelta(days=1) # Keep for 1 day past expiry
|
||||
await db.execute(
|
||||
delete(RevokedToken).where(
|
||||
(RevokedToken.token_type == "access") &
|
||||
(RevokedToken.created_at < expire_before)
|
||||
)
|
||||
)
|
||||
|
||||
# For refresh tokens (longer expiry)
|
||||
expire_before = now - timedelta(days=14) # Keep for 14 days past expiry
|
||||
await db.execute(
|
||||
delete(RevokedToken).where(
|
||||
(RevokedToken.token_type == "refresh") &
|
||||
(RevokedToken.created_at < expire_before)
|
||||
)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
Reference in New Issue
Block a user