This commit introduces a system for revoking tokens by storing their `jti` in a new `RevokedToken` model. It adds APIs for logging out (revoking the current token) and logging out from all devices (revoking all tokens). Token validation now also checks revocation status when a token is decoded.
148 lines · 4.6 KiB · Python
from datetime import timedelta, datetime
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
from jose import jwt, JWTError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.auth.security import (
|
|
get_password_hash, verify_password,
|
|
create_access_token, create_refresh_token,
|
|
decode_token, SECRET_KEY, ALGORITHM
|
|
)
|
|
from app.schemas.token import TokenPayload
|
|
|
|
|
|
def test_password_hashing():
    """Hashing never yields the plaintext, and verification accepts only the original password."""
    secret = "securepassword123"
    digest = get_password_hash(secret)

    # The stored hash must never be the plaintext it was derived from.
    assert digest != secret

    # Round-trip: the right password verifies, a different one does not.
    assert verify_password(secret, digest)
    assert not verify_password("wrongpassword", digest)
|
|
|
|
|
|
def test_access_token_creation():
    """create_access_token embeds the subject and tags the token as an access token."""
    subject = "123e4567-e89b-12d3-a456-426614174000"
    claims = jwt.decode(
        create_access_token({"sub": subject}),
        SECRET_KEY,
        algorithms=[ALGORITHM],
    )

    assert claims.get("sub") == subject
    assert claims.get("type") == "access"
|
|
|
|
|
|
def test_refresh_token_creation():
    """create_refresh_token embeds the subject and tags the token as a refresh token."""
    subject = "123e4567-e89b-12d3-a456-426614174000"
    claims = jwt.decode(
        create_refresh_token({"sub": subject}),
        SECRET_KEY,
        algorithms=[ALGORITHM],
    )

    assert claims.get("sub") == subject
    assert claims.get("type") == "refresh"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_decode_token_valid():
    """A well-formed, unrevoked access token decodes into a TokenPayload."""
    subject = "123e4567-e89b-12d3-a456-426614174000"
    token = create_access_token({"sub": subject, "jti": "test-jti"})

    # The revocation lookup goes through db.get; returning None means
    # "no revocation record found", i.e. is_token_revoked reports False.
    session = AsyncMock(spec=AsyncSession)
    session.get = AsyncMock(return_value=None)

    payload = await decode_token(token, db=session)

    assert isinstance(payload, TokenPayload)
    assert (payload.sub, payload.type) == (subject, "access")
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_decode_token_expired():
    """An access token whose ``exp`` lies in the past must be rejected as expired.

    The mocked session reports the token as not revoked, so expiry is the
    only condition that can make decoding fail here.
    """
    user_id = "123e4567-e89b-12d3-a456-426614174000"
    token = create_access_token(
        {"sub": user_id, "jti": "test-jti"},
        expires_delta=timedelta(seconds=-1),
    )

    # Without an explicit return value, awaiting an AsyncMock child yields a
    # truthy MagicMock, so the revocation lookup would treat the token as
    # revoked and mask the expiry failure this test is meant to cover.
    mock_db = AsyncMock(spec=AsyncSession)
    mock_db.get = AsyncMock(return_value=None)

    with pytest.raises(JWTError) as exc_info:
        await decode_token(token, db=mock_db)

    # NOTE(review): exact wording depends on decode_token; python-jose's
    # ExpiredSignatureError reads "Signature has expired." - confirm and
    # tighten to an equality check once the message is pinned down.
    assert "expired" in str(exc_info.value).lower()
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_decode_token_missing_exp():
    """A correctly signed token without an ``exp`` claim is rejected as malformed."""
    # Hand-built token: every claim decode_token needs except ``exp``.
    payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access", "jti": "test-jti"}
    token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)

    # Report "not revoked" so the missing claim is the only possible failure;
    # an unconfigured AsyncMock.get returns a truthy MagicMock instead.
    mock_db = AsyncMock(spec=AsyncSession)
    mock_db.get = AsyncMock(return_value=None)

    with pytest.raises(JWTError) as exc_info:
        await decode_token(token, db=mock_db)

    assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_decode_token_missing_sub():
    """A correctly signed, unexpired token without a ``sub`` claim is rejected as malformed."""
    # ``exp`` is one minute in the future so only the missing ``sub`` is wrong.
    payload = {"exp": datetime.now().timestamp() + 60, "type": "access", "jti": "test-jti"}
    token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)

    # Report "not revoked" so the missing claim is the only possible failure;
    # an unconfigured AsyncMock.get returns a truthy MagicMock instead.
    mock_db = AsyncMock(spec=AsyncSession)
    mock_db.get = AsyncMock(return_value=None)

    with pytest.raises(JWTError) as exc_info:
        await decode_token(token, db=mock_db)

    assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_decode_token_invalid_signature():
    """A token signed with a key other than SECRET_KEY fails signature verification."""
    forged = jwt.encode(
        {"sub": "123", "type": "access", "jti": "test-jti"},
        "wrong_secret",
        algorithm=ALGORITHM,
    )
    session = AsyncMock(spec=AsyncSession)

    with pytest.raises(JWTError) as caught:
        await decode_token(forged, db=session)

    assert str(caught.value) == "Signature verification failed."
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_decode_token_malformed():
    """A string that merely looks like three JWT segments is rejected as invalid."""
    session = AsyncMock(spec=AsyncSession)

    with pytest.raises(JWTError) as caught:
        await decode_token("malformed.header.payload", db=session)

    assert str(caught.value) == "Invalid token."
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_decode_token_invalid_type():
    """Decoding a refresh token where an access token is required is rejected."""
    user_id = "123e4567-e89b-12d3-a456-426614174000"
    token = create_refresh_token({"sub": user_id, "jti": "test-jti"})  # type == "refresh"

    # This token is otherwise valid, so report "not revoked" explicitly; an
    # unconfigured AsyncMock.get returns a truthy MagicMock, which could let
    # a revocation failure mask the type mismatch under test.
    mock_db = AsyncMock(spec=AsyncSession)
    mock_db.get = AsyncMock(return_value=None)

    with pytest.raises(JWTError) as exc_info:
        await decode_token(token, required_type="access", db=mock_db)

    assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."
|