Remove token revocation logic and unused dependencies

Eliminated the `RevokedToken` model and associated logic for managing token revocation. Removed unused files, related tests, and outdated dependencies in authentication modules. Simplified token decoding, user validation, and dependency injection by streamlining the flow and enhancing maintainability.
This commit is contained in:
2025-03-02 11:04:12 +01:00
parent 453016629f
commit cd92cd9780
24 changed files with 954 additions and 781 deletions

View File

@@ -1,40 +0,0 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.auth.security import decode_token
from app.models.user import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db)
):
try:
payload = await decode_token(token) # Use updated decode_token.
user_id: str = payload.sub
token_type: str = payload.type
if user_id is None or token_type != "access":
raise HTTPException(status_code=401, detail="Invalid token type.")
user = await db.get(User, user_id)
if user is None:
raise HTTPException(status_code=401, detail="User not found.")
return user
except JWTError as e:
raise HTTPException(status_code=401, detail=str(e))
async def get_current_active_user(
current_user: User = Depends(get_current_user),
):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user

View File

@@ -1,176 +0,0 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from uuid import uuid4
from fastapi import Depends
from jose import jwt, JWTError, ExpiredSignatureError, JOSEError
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.database import get_db
from app.schemas.token import TokenPayload, TokenResponse
from auth.utils import is_token_revoked
# Configuration
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a plain password against its hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate password hash."""
return pwd_context.hash(password)
def create_tokens(user_id: str) -> TokenResponse:
"""
Create both access and refresh tokens for a user.
Args:
user_id: The user's ID
Returns:
TokenResponse containing both tokens and metadata
"""
# Add `jti` during token creation
access_token = create_access_token({"sub": user_id, "jti": str(uuid4())})
refresh_token = create_refresh_token({"sub": user_id, "jti": str(uuid4())})
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
user_id=user_id,
scope="read write"
)
def create_token(
data: dict,
expires_delta: Optional[timedelta] = None,
token_type: str = "access"
) -> str:
"""Create a JWT token with the specified type and expiration."""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + (
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if token_type == "access"
else timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
)
to_encode.update({
"exp": expire,
"type": token_type,
"iat": datetime.now(timezone.utc),
})
if "jti" not in to_encode:
to_encode["jti"] = str(uuid4()) # Ensure unique `jti` is always added
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new access token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "access")
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new refresh token."""
# Ensure `data` includes `jti` for consistency
if "jti" not in data:
data["jti"] = str(uuid4())
return create_token(data, expires_delta, "refresh")
async def decode_token(
token: str,
required_type: str = "access",
db: AsyncSession = Depends(get_db)
) -> TokenPayload:
"""
Decode and validate a JWT token, including revocation checks.
Args:
token (str): The JWT token to decode.
required_type (str): The expected token type (default: "access").
db (AsyncSession): Database session for token revocation checks.
Returns:
TokenPayload: The decoded token data.
Raises:
JWTError: If the token is expired, revoked, malformed, or fails validation.
"""
try:
# Step 1: Decode the JWT token
payload = jwt.decode(
token,
SECRET_KEY,
algorithms=[ALGORITHM],
options={
"verify_exp": True,
"verify_iat": True,
"require": ["exp", "iat", "sub", "type", "jti"]
}
)
except ExpiredSignatureError:
raise JWTError("Token has expired. Please refresh your token or login again.")
except JWTError as e:
if "Signature verification failed" in str(e):
raise JWTError("Invalid token signature. The token may have been tampered with or corrupted.")
raise JWTError(f"Failed to decode the token: {e}")
except JOSEError as e:
if "segments" in str(e).lower():
raise JWTError("Malformed token. The token format is invalid (e.g., not enough segments).")
raise JWTError("Failed to decode the token. Ensure the token is valid and correctly formatted.") from e
except Exception as e:
# Catch-all for unexpected exceptions during decoding
raise JWTError(f"An unexpected error occurred while decoding the token: {e}")
# Step 2: Validate Required Claims
required_claims = ["exp", "sub", "type", "jti"]
missing_claims = [claim for claim in required_claims if claim not in payload]
if missing_claims:
raise JWTError(f"Malformed token. Missing required claims: {', '.join(missing_claims)}.")
# Step 3: Validate Expiry
expiration = datetime.fromtimestamp(payload["exp"])
if datetime.now(timezone.utc) > expiration:
raise JWTError("Token has expired. Please refresh your token or login again.")
# Step 4: Validate Token Type
token_type = payload.get("type")
if token_type != required_type:
raise JWTError(f"Invalid token type: expected '{required_type}', got '{token_type}'.")
# Step 5: Check Revocation
jti = payload.get("jti")
if await is_token_revoked(jti, db):
raise JWTError("Token has been revoked. Please login again to generate a new token.")
# Step 6: Return Validated Token Payload
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=expiration,
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=jti
)

View File

@@ -1,45 +0,0 @@
from datetime import datetime, timezone, timedelta
from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.token import RevokedToken
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
"""Revoke a token by storing its `jti` in the revoked_tokens table."""
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
db.add(revoked_token)
await db.commit()
async def is_token_revoked(jti: str, db: AsyncSession) -> bool:
"""Check whether the token's JTI is in the revoked_tokens table."""
from sqlalchemy import select
result = await db.execute(select(RevokedToken).where(RevokedToken.jti == jti))
revoked = result.scalar_one_or_none()
return revoked is not None
async def cleanup_expired_tokens(db: AsyncSession):
"""Delete revoked tokens that are past their expiration time."""
now = datetime.now(timezone.utc)
# For access tokens (shorter expiry)
expire_before = now - timedelta(days=1) # Keep for 1 day past expiry
await db.execute(
delete(RevokedToken).where(
(RevokedToken.token_type == "access") &
(RevokedToken.created_at < expire_before)
)
)
# For refresh tokens (longer expiry)
expire_before = now - timedelta(days=14) # Keep for 14 days past expiry
await db.execute(
delete(RevokedToken).where(
(RevokedToken.token_type == "refresh") &
(RevokedToken.created_at < expire_before)
)
)
await db.commit()