Remove token revocation logic and unused dependencies

Eliminated the `RevokedToken` model and associated logic for managing token revocation. Removed unused files, related tests, and outdated dependencies in authentication modules. Simplified token decoding, user validation, and dependency injection by streamlining the flow and enhancing maintainability.
This commit is contained in:
2025-03-02 11:04:12 +01:00
parent 453016629f
commit cd92cd9780
24 changed files with 954 additions and 781 deletions

View File

@@ -1,85 +0,0 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock
import pytest
from fastapi import HTTPException
from jose import jwt
from app.auth.dependencies import get_current_user, get_current_active_user
from app.auth.security import SECRET_KEY, ALGORITHM
from app.models.user import User
@pytest.fixture
def mock_user():
return User(
id="123e4567-e89b-12d3-a456-426614174000",
email="test@example.com",
password_hash="hashedpassword",
is_active=True
)
@pytest.mark.asyncio
async def test_get_current_user_success(mock_user):
# Create a valid access token with required claims
valid_token = jwt.encode(
{"sub": str(mock_user.id), "type": "access", "exp": datetime.now(tz=timezone.utc).timestamp() + 3600},
SECRET_KEY,
algorithm=ALGORITHM
)
# Mock database session
mock_db = AsyncMock()
mock_db.get.return_value = mock_user # Ensure `db.get()` returns the mock_user
# Call the dependency
user = await get_current_user(token=valid_token, db=mock_db)
# Assertions
assert user == mock_user, "Returned user does not match the mocked user."
mock_db.get.assert_called_once_with(User, mock_user.id)
@pytest.mark.asyncio
async def test_get_current_user_invalid_token():
invalid_token = "invalid.token.payload"
with pytest.raises(HTTPException) as exc_info:
await get_current_user(token=invalid_token, db=AsyncMock())
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Could not validate credentials"
@pytest.mark.asyncio
async def test_get_current_user_wrong_token_type():
token = jwt.encode({"sub": "123", "type": "refresh"}, SECRET_KEY, algorithm=ALGORITHM)
with pytest.raises(HTTPException) as exc_info:
await get_current_user(token=token, db=AsyncMock())
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Could not validate credentials"
@pytest.mark.asyncio
async def test_get_current_active_user_success(mock_user):
result = await get_current_active_user(mock_user)
assert result == mock_user
@pytest.mark.asyncio
async def test_get_current_active_user_inactive():
inactive_user = User(
id="123e4567-e89b-12d3-a456-426614174000",
email="inactive@example.com",
password_hash="hashedpassword",
is_active=False
)
with pytest.raises(HTTPException) as exc_info:
await get_current_active_user(inactive_user)
assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Inactive user"

View File

@@ -1,147 +0,0 @@
from datetime import timedelta, datetime
from unittest.mock import AsyncMock
import pytest
from jose import jwt, JWTError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.security import (
get_password_hash, verify_password,
create_access_token, create_refresh_token,
decode_token, SECRET_KEY, ALGORITHM
)
from app.schemas.token import TokenPayload
def test_password_hashing():
plain_password = "securepassword123"
hashed_password = get_password_hash(plain_password)
# Ensure hashed passwords are not the same
assert hashed_password != plain_password
# Test password verification
assert verify_password(plain_password, hashed_password)
assert not verify_password("wrongpassword", hashed_password)
def test_access_token_creation():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_access_token({"sub": user_id})
decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
assert decoded_payload.get("sub") == user_id
assert decoded_payload.get("type") == "access"
def test_refresh_token_creation():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_refresh_token({"sub": user_id})
decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
assert decoded_payload.get("sub") == user_id
assert decoded_payload.get("type") == "refresh"
@pytest.mark.asyncio
async def test_decode_token_valid():
user_id = "123e4567-e89b-12d3-a456-426614174000"
access_token = create_access_token({"sub": user_id, "jti": "test-jti"})
# Mock is_token_revoked to return False
mock_db = AsyncMock(spec=AsyncSession)
mock_db.get = AsyncMock(return_value=None)
token_payload = await decode_token(access_token, db=mock_db)
assert isinstance(token_payload, TokenPayload)
assert token_payload.sub == user_id
assert token_payload.type == "access"
@pytest.mark.asyncio
async def test_decode_token_expired():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_access_token({"sub": user_id, "jti": "test-jti"}, expires_delta=timedelta(seconds=-1))
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Token has been revoked."
@pytest.mark.asyncio
async def test_decode_token_missing_exp():
# Create a token without the `exp` claim
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access", "jti": "test-jti"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
@pytest.mark.asyncio
async def test_decode_token_missing_sub():
# Create a token without the `sub` claim
payload = {"exp": datetime.now().timestamp() + 60, "type": "access", "jti": "test-jti"}
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
@pytest.mark.asyncio
async def test_decode_token_invalid_signature():
# Use a different secret key for signing
token = jwt.encode({"sub": "123", "type": "access", "jti": "test-jti"}, "wrong_secret", algorithm=ALGORITHM)
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, db=mock_db)
assert str(exc_info.value) == "Signature verification failed."
@pytest.mark.asyncio
async def test_decode_token_malformed():
malformed_token = "malformed.header.payload"
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(malformed_token, db=mock_db)
assert str(exc_info.value) == "Invalid token."
@pytest.mark.asyncio
async def test_decode_token_invalid_type():
user_id = "123e4567-e89b-12d3-a456-426614174000"
token = create_refresh_token({"sub": user_id, "jti": "test-jti"}) # Token type is "refresh"
# Mock database
mock_db = AsyncMock(spec=AsyncSession)
with pytest.raises(JWTError) as exc_info:
await decode_token(token, required_type="access", db=mock_db) # Expecting an access token
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."

View File

@@ -9,7 +9,7 @@ from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory,
EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \
NotificationType, NotificationStatus
from app.models.user import User
from app.utils.test_utils import setup_test_db, teardown_test_db
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
@pytest.fixture(scope="function")
@@ -30,6 +30,15 @@ def db_session():
teardown_test_db(test_engine)
@pytest.fixture(scope="function") # Define a fixture
async def async_test_db():
"""Fixture provides new testing engine and session for each test run to improve isolation."""
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
yield test_engine, AsyncTestingSessionLocal
await teardown_async_test_db(test_engine)
@pytest.fixture
def mock_user(db_session):
"""Fixture to create and return a mock User instance."""
@@ -72,7 +81,6 @@ def event_fixture(db_session, mock_user):
return event
@pytest.fixture
def gift_item_fixture(db_session, mock_user):
"""

View File

View File

@@ -0,0 +1,260 @@
# tests/core/test_auth.py
import uuid
import pytest
from datetime import datetime, timedelta, timezone
from jose import jwt
from pydantic import ValidationError
from app.core.auth import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
decode_token,
get_token_data,
TokenExpiredError,
TokenInvalidError,
TokenMissingClaimError
)
from app.core.config import settings
class TestPasswordHandling:
"""Tests for password hashing and verification functions"""
def test_password_hash_different_from_password(self):
"""Test that a password hash is different from the original password"""
password = "TestPassword123"
hashed = get_password_hash(password)
assert hashed != password
def test_verify_correct_password(self):
"""Test that verify_password returns True for the correct password"""
password = "TestPassword123"
hashed = get_password_hash(password)
assert verify_password(password, hashed) is True
def test_verify_incorrect_password(self):
"""Test that verify_password returns False for an incorrect password"""
password = "TestPassword123"
wrong_password = "WrongPassword123"
hashed = get_password_hash(password)
assert verify_password(wrong_password, hashed) is False
def test_same_password_different_hash(self):
"""Test that the same password gets a different hash each time"""
password = "TestPassword123"
hash1 = get_password_hash(password)
hash2 = get_password_hash(password)
assert hash1 != hash2
class TestTokenCreation:
"""Tests for token creation functions"""
def test_create_access_token(self):
"""Test that an access token is created with the correct claims"""
user_id = str(uuid.uuid4())
custom_claims = {
"email": "test@example.com",
"first_name": "Test",
"is_superuser": True
}
token = create_access_token(subject=user_id, claims=custom_claims)
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check standard claims
assert payload["sub"] == user_id
assert "jti" in payload
assert "exp" in payload
assert "iat" in payload
assert payload["type"] == "access"
# Check custom claims
for key, value in custom_claims.items():
assert payload[key] == value
def test_create_refresh_token(self):
"""Test that a refresh token is created with the correct claims"""
user_id = str(uuid.uuid4())
token = create_refresh_token(subject=user_id)
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Check standard claims
assert payload["sub"] == user_id
assert "jti" in payload
assert "exp" in payload
assert "iat" in payload
assert payload["type"] == "refresh"
def test_token_expiration(self):
"""Test that tokens have the correct expiration time"""
user_id = str(uuid.uuid4())
expires = timedelta(minutes=5)
# Create token with specific expiration
token = create_access_token(
subject=user_id,
expires_delta=expires
)
# Decode token
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
)
# Get actual expiration time from token
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
# Calculate expected expiration (approximately)
now = datetime.now(timezone.utc)
expected_expiration = now + expires
# Difference should be small (less than 1 second)
difference = abs((expiration - expected_expiration).total_seconds())
assert difference < 1
class TestTokenDecoding:
"""Tests for token decoding and validation functions"""
def test_decode_valid_token(self):
"""Test that a valid token can be decoded"""
user_id = str(uuid.uuid4())
token = create_access_token(subject=user_id)
# Decode token
payload = decode_token(token)
# Check that the subject matches
assert payload.sub == user_id
def test_decode_expired_token(self):
"""Test that an expired token raises TokenExpiredError"""
user_id = str(uuid.uuid4())
# Create a token that's already expired by directly manipulating the payload
now = datetime.now(timezone.utc)
expired_time = now - timedelta(hours=1) # 1 hour in the past
# Create the expired token manually
payload = {
"sub": user_id,
"exp": int(expired_time.timestamp()), # Set expiration in the past
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
}
expired_token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Attempting to decode should raise TokenExpiredError
with pytest.raises(TokenExpiredError):
decode_token(expired_token)
def test_decode_invalid_token(self):
"""Test that an invalid token raises TokenInvalidError"""
invalid_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature"
with pytest.raises(TokenInvalidError):
decode_token(invalid_token)
def test_decode_token_with_missing_sub(self):
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
# Create a token without a subject
now = datetime.now(timezone.utc)
payload = {
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
# No 'sub' claim
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
with pytest.raises(TokenMissingClaimError):
decode_token(token)
def test_decode_token_with_wrong_type(self):
"""Test that verifying a token with wrong type raises TokenInvalidError"""
user_id = str(uuid.uuid4())
token = create_access_token(subject=user_id)
# Try to verify it as a refresh token
with pytest.raises(TokenInvalidError):
decode_token(token, verify_type="refresh")
def test_decode_with_invalid_payload(self):
"""Test that a token with invalid payload structure raises TokenInvalidError"""
# Create a token with an invalid payload structure - missing 'sub' which is required
# but including 'exp' to avoid the expiration check
now = datetime.now(timezone.utc)
payload = {
# Missing "sub" field which is required
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"invalid_field": "test"
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Should raise TokenMissingClaimError due to missing 'sub'
with pytest.raises(TokenMissingClaimError):
decode_token(token)
# Create another token with invalid type for required field
payload = {
"sub": 123, # sub should be a string, not an integer
"exp": int((now + timedelta(minutes=30)).timestamp()),
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Should raise TokenInvalidError due to ValidationError
with pytest.raises(TokenInvalidError):
decode_token(token)
def test_get_token_data(self):
"""Test extracting TokenData from a token"""
user_id = uuid.uuid4()
token = create_access_token(
subject=str(user_id),
claims={"is_superuser": True}
)
token_data = get_token_data(token)
assert token_data.user_id == user_id
assert token_data.is_superuser is True

View File