Refactor token handling and introduce token revocation logic
Updated `decode_token` for stricter validation of token claims and explicit error handling. Added utilities for token revocation and verification, improving
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
@@ -21,13 +22,22 @@ def mock_user():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_success(mock_user):
|
||||
valid_token = jwt.encode({"sub": str(mock_user.id), "type": "access"}, SECRET_KEY, algorithm=ALGORITHM)
|
||||
# Create a valid access token with required claims
|
||||
valid_token = jwt.encode(
|
||||
{"sub": str(mock_user.id), "type": "access", "exp": datetime.now(tz=timezone.utc).timestamp() + 3600},
|
||||
SECRET_KEY,
|
||||
algorithm=ALGORITHM
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get.return_value = mock_user
|
||||
mock_db.get.return_value = mock_user # Ensure `db.get()` returns the mock_user
|
||||
|
||||
# Call the dependency
|
||||
user = await get_current_user(token=valid_token, db=mock_db)
|
||||
assert user == mock_user
|
||||
|
||||
# Assertions
|
||||
assert user == mock_user, "Returned user does not match the mocked user."
|
||||
mock_db.get.assert_called_once_with(User, mock_user.id)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
from datetime import timedelta
|
||||
from jose import jwt, JWTError
|
||||
from datetime import timedelta, datetime, timezone
|
||||
from jose import jwt, JWTError, ExpiredSignatureError
|
||||
from app.auth.security import (
|
||||
get_password_hash, verify_password,
|
||||
create_access_token, create_refresh_token,
|
||||
@@ -52,17 +52,59 @@ def test_decode_token_expired():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_access_token({"sub": user_id}, expires_delta=timedelta(seconds=-1))
|
||||
|
||||
with pytest.raises(JWTError):
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
|
||||
assert str(exc_info.value) == "Token expired. Please refresh your token to continue."
|
||||
|
||||
|
||||
def test_decode_token_missing_exp():
|
||||
# Create a token without the `exp` claim
|
||||
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim."
|
||||
|
||||
|
||||
|
||||
def test_decode_token_missing_sub():
|
||||
# Create a token without the `sub` claim
|
||||
payload = {"exp": datetime.now().timestamp() + 60, "type": "access"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim."
|
||||
|
||||
|
||||
def test_decode_token_invalid_signature():
|
||||
token = jwt.encode({"some": "data"}, "invalid_key", algorithm=ALGORITHM)
|
||||
with pytest.raises(JWTError):
|
||||
# Use a different secret key for signing
|
||||
token = jwt.encode({"sub": "123", "type": "access"}, "wrong_secret", algorithm=ALGORITHM)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token)
|
||||
|
||||
assert str(exc_info.value) == "Invalid token."
|
||||
|
||||
|
||||
def test_decode_token_malformed():
|
||||
malformed_token = "malformed.header.payload"
|
||||
with pytest.raises(JWTError):
|
||||
decode_token(malformed_token)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(malformed_token)
|
||||
|
||||
assert str(exc_info.value) == "Invalid token."
|
||||
|
||||
|
||||
def test_decode_token_invalid_type():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_refresh_token({"sub": user_id}) # Token type is "refresh"
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
decode_token(token, required_type="access") # Expecting an access token
|
||||
|
||||
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."
|
||||
|
||||
Reference in New Issue
Block a user