Add token revocation mechanism and support for logout APIs

This commit introduces a system to revoke tokens by storing their `jti` in a new `RevokedToken` model. It includes APIs for logging out (revoking a current token) and logging out from all devices (revoking all tokens). Additionally, token validation now checks revocation status during the decode process.
This commit is contained in:
2025-02-28 17:45:33 +01:00
parent aa77752981
commit 8814dc931f
8 changed files with 270 additions and 208 deletions

View File

@@ -1,7 +1,9 @@
from datetime import timedelta, datetime
from unittest.mock import AsyncMock
import pytest
from jose import jwt, JWTError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.security import (
get_password_hash, verify_password,
@@ -40,72 +42,106 @@ def test_refresh_token_creation():
assert decoded_payload.get("type") == "refresh"
def test_decode_token_valid():
@pytest.mark.asyncio
async def test_decode_token_valid():
user_id = "123e4567-e89b-12d3-a456-426614174000"
access_token = create_access_token({"sub": user_id})
token_payload = decode_token(access_token)
access_token = create_access_token({"sub": user_id, "jti": "test-jti"})
# Mock is_token_revoked to return False
mock_db = AsyncMock(spec=AsyncSession)
mock_db.get = AsyncMock(return_value=None)
token_payload = await decode_token(access_token, db=mock_db)
assert isinstance(token_payload, TokenPayload)
assert token_payload.sub == user_id
assert token_payload.type == "access"
def test_decode_token_expired():
@pytest.mark.asyncio
async def test_decode_token_expired():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_access_token({"sub": user_id}, expires_delta=timedelta(seconds=-1))
token = create_access_token({"sub": user_id, "jti": "test-jti"}, expires_delta=timedelta(seconds=-1))
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Token expired. Please refresh your token to continue."
assert str(exc_info.value) == "Token has been revoked."
def test_decode_token_missing_exp():
@pytest.mark.asyncio
async def test_decode_token_missing_exp():
# Create a token without the `exp` claim
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access"}
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access", "jti": "test-jti"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Malformed token. Missing required claim."
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
def test_decode_token_missing_sub():
@pytest.mark.asyncio
async def test_decode_token_missing_sub():
# Create a token without the `sub` claim
payload = {"exp": datetime.now().timestamp() + 60, "type": "access"}
payload = {"exp": datetime.now().timestamp() + 60, "type": "access", "jti": "test-jti"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Malformed token. Missing required claim."
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
def test_decode_token_invalid_signature():
@pytest.mark.asyncio
async def test_decode_token_invalid_signature():
# Use a different secret key for signing
token = jwt.encode({"sub": "123", "type": "access"}, "wrong_secret", algorithm=ALGORITHM)
token = jwt.encode({"sub": "123", "type": "access", "jti": "test-jti"}, "wrong_secret", algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token)
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Invalid token."
assert str(exc_info.value) == "Signature verification failed."
def test_decode_token_malformed():
@pytest.mark.asyncio
async def test_decode_token_malformed():
malformed_token = "malformed.header.payload"
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(malformed_token)
await decode_token(malformed_token, db=mock_db)
assert str(exc_info.value) == "Invalid token."
def test_decode_token_invalid_type():
@pytest.mark.asyncio
async def test_decode_token_invalid_type():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_refresh_token({"sub": user_id}) # Token type is "refresh"
token = create_refresh_token({"sub": user_id, "jti": "test-jti"}) # Token type is "refresh"
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
decode_token(token, required_type="access") # Expecting an access token
await decode_token(token, required_type="access", db=mock_db) # Expecting an access token
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."