Refactor token handling and introduce token revocation logic

Updated `decode_token` for stricter validation of token claims and explicit error handling. Added utilities for token revocation and verification, improving
This commit is contained in:
2025-02-28 16:57:57 +01:00
parent c3a55b26c7
commit 548880b468
7 changed files with 124 additions and 36 deletions

View File

@@ -4,7 +4,7 @@ from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from auth.security import SECRET_KEY, ALGORITHM
from app.auth.security import decode_token
from app.models.user import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
@@ -14,28 +14,22 @@ async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db)
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
token_type: str = payload.get("type")
payload = decode_token(token) # Use updated decode_token.
user_id: str = payload.sub
token_type: str = payload.type
if user_id is None or token_type != "access":
raise credentials_exception
raise HTTPException(status_code=401, detail="Invalid token type.")
except JWTError:
raise credentials_exception
user = await db.get(User, user_id)
if user is None:
raise HTTPException(status_code=401, detail="User not found.")
user = await db.get(User, user_id)
if user is None:
raise credentials_exception
return user
except JWTError as e:
raise HTTPException(status_code=401, detail=str(e))
return user
async def get_current_active_user(

View File

@@ -2,7 +2,8 @@ from datetime import datetime, timedelta
from typing import Optional, Tuple
from uuid import uuid4
from jose import jwt, JWTError
from black import timezone
from jose import jwt, ExpiredSignatureError, JWTError
from passlib.context import CryptContext
from app.core.config import settings
from app.schemas.token import TokenPayload, TokenResponse
@@ -75,30 +76,54 @@ def create_token(
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def decode_token(token: str) -> TokenPayload:
def decode_token(token: str, required_type: str = "access") -> TokenPayload:
"""
Decode and validate a JWT token.
Decode and validate a JWT token with explicit edge-case handling.
Args:
token: The JWT token to decode
token: The JWT token to decode.
required_type: The expected token type (default: "access").
Returns:
TokenPayload containing the decoded data
TokenPayload containing the decoded data.
Raises:
JWTError: If token is invalid or expired
JWTError: If the token is expired, invalid, or malformed.
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# Explicitly validate required claims (`exp`, `sub`, `type`)
if "exp" not in payload or "sub" not in payload or "type" not in payload:
raise KeyError("Missing required claim.")
# Verify token expiration (`exp`)
if datetime.now() > datetime.fromtimestamp(payload["exp"]):
raise ExpiredSignatureError("Token has expired.")
# Verify token type (`type`)
if payload["type"] != required_type:
expected_type = required_type
actual_type = payload["type"]
raise JWTError(f"Invalid token type: expected '{expected_type}', got '{actual_type}'.")
# Create TokenPayload object from token claims
return TokenPayload(
sub=payload["sub"],
type=payload["type"],
exp=datetime.fromtimestamp(payload["exp"]),
iat=datetime.fromtimestamp(payload["iat"]),
iat=datetime.fromtimestamp(payload.get("iat", 0)),
jti=payload.get("jti")
)
except KeyError as e:
raise JWTError("Malformed token. Missing required claim.") from e
except ExpiredSignatureError as e:
raise JWTError("Token expired. Please refresh your token to continue.") from e
except JWTError as e:
raise JWTError(f"Invalid token: {str(e)}")
raise JWTError(str(e)) from e
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
@@ -108,4 +133,4 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a new refresh token."""
return create_token(data, expires_delta, "refresh")
return create_token(data, expires_delta, "refresh")

16
backend/app/auth/utlis.py Normal file
View File

@@ -0,0 +1,16 @@
# auth/utils.py
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.token import RevokedToken
async def revoke_token(jti: str, token_type: str, user_id: str, db: AsyncSession):
"""Revoke a token by adding its `jti` to the database."""
revoked_token = RevokedToken(jti=jti, token_type=token_type, user_id=user_id)
db.add(revoked_token)
await db.commit()
async def is_token_revoked(jti: str, db: AsyncSession):
"""Check if a token with the given `jti` is revoked."""
result = await db.get(RevokedToken, jti)
return result is not None

View File

View File

@@ -7,3 +7,4 @@ addopts = --disable-warnings
markers =
sqlite: marks tests that should run on SQLite (mocked).
postgres: marks tests that require a real PostgreSQL database.
asyncio_default_fixture_loop_scope = function

View File

@@ -1,3 +1,4 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import pytest
@@ -21,13 +22,22 @@ def mock_user():
@pytest.mark.asyncio
async def test_get_current_user_success(mock_user):
valid_token = jwt.encode({"sub": str(mock_user.id), "type": "access"}, SECRET_KEY, algorithm=ALGORITHM)
# Create a valid access token with required claims
valid_token = jwt.encode(
{"sub": str(mock_user.id), "type": "access", "exp": datetime.now(tz=timezone.utc).timestamp() + 3600},
SECRET_KEY,
algorithm=ALGORITHM
)
# Mock database session
mock_db = AsyncMock()
mock_db.get.return_value = mock_user
mock_db.get.return_value = mock_user # Ensure `db.get()` returns the mock_user
# Call the dependency
user = await get_current_user(token=valid_token, db=mock_db)
assert user == mock_user
# Assertions
assert user == mock_user, "Returned user does not match the mocked user."
mock_db.get.assert_called_once_with(User, mock_user.id)

View File

@@ -1,6 +1,6 @@
import pytest
from datetime import timedelta
from jose import jwt, JWTError
from datetime import timedelta, datetime, timezone
from jose import jwt, JWTError, ExpiredSignatureError
from app.auth.security import (
get_password_hash, verify_password,
create_access_token, create_refresh_token,
@@ -52,17 +52,59 @@ def test_decode_token_expired():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_access_token({"sub": user_id}, expires_delta=timedelta(seconds=-1))
with pytest.raises(JWTError):
with pytest.raises(JWTError) as exc_info:
decode_token(token)
assert str(exc_info.value) == "Token expired. Please refresh your token to continue."
def test_decode_token_missing_exp():
# Create a token without the `exp` claim
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
assert str(exc_info.value) == "Malformed token. Missing required claim."
def test_decode_token_missing_sub():
# Create a token without the `sub` claim
payload = {"exp": datetime.now().timestamp() + 60, "type": "access"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
assert str(exc_info.value) == "Malformed token. Missing required claim."
def test_decode_token_invalid_signature():
token = jwt.encode({"some": "data"}, "invalid_key", algorithm=ALGORITHM)
with pytest.raises(JWTError):
# Use a different secret key for signing
token = jwt.encode({"sub": "123", "type": "access"}, "wrong_secret", algorithm=ALGORITHM)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
assert str(exc_info.value) == "Invalid token."
def test_decode_token_malformed():
malformed_token = "malformed.header.payload"
with pytest.raises(JWTError):
decode_token(malformed_token)
with pytest.raises(JWTError) as exc_info:
decode_token(malformed_token)
assert str(exc_info.value) == "Invalid token."
def test_decode_token_invalid_type():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_refresh_token({"sub": user_id}) # Token type is "refresh"
with pytest.raises(JWTError) as exc_info:
decode_token(token, required_type="access") # Expecting an access token
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."