Refactor token handling and introduce token revocation logic

Updated `decode_token` for stricter validation of token claims and explicit error handling. Added utilities for token revocation and verification, improving
This commit is contained in:
2025-02-28 16:57:57 +01:00
parent c3a55b26c7
commit 548880b468
7 changed files with 124 additions and 36 deletions

View File

@@ -1,6 +1,6 @@
import pytest
from datetime import timedelta
from jose import jwt, JWTError
from datetime import timedelta, datetime, timezone
from jose import jwt, JWTError, ExpiredSignatureError
from app.auth.security import (
get_password_hash, verify_password,
create_access_token, create_refresh_token,
@@ -52,17 +52,59 @@ def test_decode_token_expired():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_access_token({"sub": user_id}, expires_delta=timedelta(seconds=-1))
with pytest.raises(JWTError):
with pytest.raises(JWTError) as exc_info:
decode_token(token)
assert str(exc_info.value) == "Token expired. Please refresh your token to continue."
def test_decode_token_missing_exp():
# Create a token without the `exp` claim
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
assert str(exc_info.value) == "Malformed token. Missing required claim."
def test_decode_token_missing_sub():
# Create a token without the `sub` claim
payload = {"exp": datetime.now().timestamp() + 60, "type": "access"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
assert str(exc_info.value) == "Malformed token. Missing required claim."
def test_decode_token_invalid_signature():
token = jwt.encode({"some": "data"}, "invalid_key", algorithm=ALGORITHM)
with pytest.raises(JWTError):
# Use a different secret key for signing
token = jwt.encode({"sub": "123", "type": "access"}, "wrong_secret", algorithm=ALGORITHM)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
assert str(exc_info.value) == "Invalid token."
def test_decode_token_malformed():
malformed_token = "malformed.header.payload"
with pytest.raises(JWTError):
decode_token(malformed_token)
with pytest.raises(JWTError) as exc_info:
decode_token(malformed_token)
assert str(exc_info.value) == "Invalid token."
def test_decode_token_invalid_type():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_refresh_token({"sub": user_id}) # Token type is "refresh"
with pytest.raises(JWTError) as exc_info:
decode_token(token, required_type="access") # Expecting an access token
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."