Refactor token handling and introduce token revocation logic

Updated `decode_token` for stricter validation of token claims and explicit error handling. Added utilities for token revocation and verification, improving
This commit is contained in:
2025-02-28 16:57:57 +01:00
parent c3a55b26c7
commit 548880b468
7 changed files with 124 additions and 36 deletions

View File

@@ -4,7 +4,7 @@ from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from auth.security import SECRET_KEY, ALGORITHM
from app.auth.security import decode_token
from app.models.user import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
@@ -14,28 +14,22 @@ async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db)
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
token_type: str = payload.get("type")
payload = decode_token(token) # Use updated decode_token.
user_id: str = payload.sub
token_type: str = payload.type
if user_id is None or token_type != "access":
raise credentials_exception
raise HTTPException(status_code=401, detail="Invalid token type.")
except JWTError:
raise credentials_exception
user = await db.get(User, user_id)
if user is None:
raise HTTPException(status_code=401, detail="User not found.")
user = await db.get(User, user_id)
if user is None:
raise credentials_exception
return user
except JWTError as e:
raise HTTPException(status_code=401, detail=str(e))
return user
async def get_current_active_user(

View File

@@ -2,7 +2,8 @@ from datetime import datetime, timedelta
from typing import Optional, Tuple
from uuid import uuid4
from jose import jwt, JWTError
from black import timezone
from jose import jwt, ExpiredSignatureError, JWTError
from passlib.context import CryptContext
from app.core.config import settings
from app.schemas.token import TokenPayload, TokenResponse
@@ -75,30 +76,54 @@ def create_token(
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def decode_token(token: str) -> TokenPayload:
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
"""
Decode and validate a JWT token.
Decode and validate a JWT token with explicit edge-case handling.
Args:
token: The JWT token to decode
token: The JWT token to decode.
required_type: The expected token type (default: "access").
Returns:
TokenPayload containing the decoded data
TokenPayload containing the decoded data.
Raises:
JWTError: If token is invalid or expired
JWTError: If the token is expired, invalid, or malformed.
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# Explicitly validate required claims (`exp`, `sub`, `type`)
if "exp" not in payload or "sub" not in payload or "type" not in payload:
raise KeyError("Missing required claim.")
# Verify token expiration (`exp`)
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
raise ExpiredSignatureError("Token has expired.")
# Verify token type (`type`)
if payload["type"] != required_type:
expected_type = required_type
actual_type = payload["type"]
raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.")
# Create TokenPayload object from token claims
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=datetime.fromtimestamp(payload["exp"]),
iat=datetime.fromtimestamp(payload["iat"]),
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=payload.get("jti")
)
except KeyError as e:
raise JWTError("Malformed token. Missing required claim.") from e
except ExpiredSignatureError as e:
raise JWTError("Token expired. Please refresh your token to continue.") from e
except JWTError as e:
raise JWTError(f"Invalid token: {str(e)}")
raise JWTError(str(e)) from e
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
@@ -108,4 +133,4 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new refresh token."""
return create_token(data, expires_delta, "refresh")
return create_token(data, expires_delta, "refresh")

16
backend/app/auth/utlis.py Normal file
View File

@@ -0,0 +1,16 @@
# auth/utils.py
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.token import RevokedToken
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
"""Revoke a token by adding its `jti` to the database."""
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
db.add(revoked_token)
await db.commit()
async def is_token_revoked(jti: str, db: AsyncSession):
"""Check if a token with the given `jti` is revoked."""
result = await db.get(RevokedToken, jti)
return result is not None

View File