Remove token revocation logic and unused dependencies
Eliminated the `RevokedToken` model and associated logic for managing token revocation. Removed unused files, related tests, and outdated dependencies in authentication modules. Simplified token decoding, user validation, and dependency injection by streamlining the flow and enhancing maintainability.
This commit is contained in:
@@ -1,85 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from jose import jwt
|
||||
|
||||
from app.auth.dependencies import get_current_user, get_current_active_user
|
||||
from app.auth.security import SECRET_KEY, ALGORITHM
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
return User(
|
||||
id="123e4567-e89b-12d3-a456-426614174000",
|
||||
email="test@example.com",
|
||||
password_hash="hashedpassword",
|
||||
is_active=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_success(mock_user):
|
||||
# Create a valid access token with required claims
|
||||
valid_token = jwt.encode(
|
||||
{"sub": str(mock_user.id), "type": "access", "exp": datetime.now(tz=timezone.utc).timestamp() + 3600},
|
||||
SECRET_KEY,
|
||||
algorithm=ALGORITHM
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get.return_value = mock_user # Ensure `db.get()` returns the mock_user
|
||||
|
||||
# Call the dependency
|
||||
user = await get_current_user(token=valid_token, db=mock_db)
|
||||
|
||||
# Assertions
|
||||
assert user == mock_user, "Returned user does not match the mocked user."
|
||||
mock_db.get.assert_called_once_with(User, mock_user.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_invalid_token():
|
||||
invalid_token = "invalid.token.payload"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(token=invalid_token, db=AsyncMock())
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Could not validate credentials"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_wrong_token_type():
|
||||
token = jwt.encode({"sub": "123", "type": "refresh"}, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(token=token, db=AsyncMock())
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Could not validate credentials"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_active_user_success(mock_user):
|
||||
result = await get_current_active_user(mock_user)
|
||||
assert result == mock_user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_active_user_inactive():
|
||||
inactive_user = User(
|
||||
id="123e4567-e89b-12d3-a456-426614174000",
|
||||
email="inactive@example.com",
|
||||
password_hash="hashedpassword",
|
||||
is_active=False
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_active_user(inactive_user)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.detail == "Inactive user"
|
||||
@@ -1,147 +0,0 @@
|
||||
from datetime import timedelta, datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from jose import jwt, JWTError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.security import (
|
||||
get_password_hash, verify_password,
|
||||
create_access_token, create_refresh_token,
|
||||
decode_token, SECRET_KEY, ALGORITHM
|
||||
)
|
||||
from app.schemas.token import TokenPayload
|
||||
|
||||
|
||||
def test_password_hashing():
|
||||
plain_password = "securepassword123"
|
||||
hashed_password = get_password_hash(plain_password)
|
||||
|
||||
# Ensure hashed passwords are not the same
|
||||
assert hashed_password != plain_password
|
||||
# Test password verification
|
||||
assert verify_password(plain_password, hashed_password)
|
||||
assert not verify_password("wrongpassword", hashed_password)
|
||||
|
||||
|
||||
def test_access_token_creation():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_access_token({"sub": user_id})
|
||||
decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
assert decoded_payload.get("sub") == user_id
|
||||
assert decoded_payload.get("type") == "access"
|
||||
|
||||
|
||||
def test_refresh_token_creation():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_refresh_token({"sub": user_id})
|
||||
decoded_payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
assert decoded_payload.get("sub") == user_id
|
||||
assert decoded_payload.get("type") == "refresh"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_valid():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
access_token = create_access_token({"sub": user_id, "jti": "test-jti"})
|
||||
|
||||
# Mock is_token_revoked to return False
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
mock_db.get = AsyncMock(return_value=None)
|
||||
|
||||
token_payload = await decode_token(access_token, db=mock_db)
|
||||
|
||||
assert isinstance(token_payload, TokenPayload)
|
||||
assert token_payload.sub == user_id
|
||||
assert token_payload.type == "access"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_expired():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_access_token({"sub": user_id, "jti": "test-jti"}, expires_delta=timedelta(seconds=-1))
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Token has been revoked."
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_missing_exp():
|
||||
# Create a token without the `exp` claim
|
||||
payload = {"sub": "123e4567-e89b-12d3-a456-426614174000", "type": "access", "jti": "test-jti"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_missing_sub():
|
||||
# Create a token without the `sub` claim
|
||||
payload = {"exp": datetime.now().timestamp() + 60, "type": "access", "jti": "test-jti"}
|
||||
token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Malformed token. Missing required claim(s)."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_invalid_signature():
|
||||
# Use a different secret key for signing
|
||||
token = jwt.encode({"sub": "123", "type": "access", "jti": "test-jti"}, "wrong_secret", algorithm=ALGORITHM)
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Signature verification failed."
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_malformed():
|
||||
malformed_token = "malformed.header.payload"
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(malformed_token, db=mock_db)
|
||||
|
||||
assert str(exc_info.value) == "Invalid token."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decode_token_invalid_type():
|
||||
user_id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
token = create_refresh_token({"sub": user_id, "jti": "test-jti"}) # Token type is "refresh"
|
||||
|
||||
# Mock database
|
||||
mock_db = AsyncMock(spec=AsyncSession)
|
||||
|
||||
with pytest.raises(JWTError) as exc_info:
|
||||
await decode_token(token, required_type="access", db=mock_db) # Expecting an access token
|
||||
|
||||
assert str(exc_info.value) == "Invalid token type: expected 'access', got 'refresh'."
|
||||
@@ -9,7 +9,7 @@ from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory,
|
||||
EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \
|
||||
NotificationType, NotificationStatus
|
||||
from app.models.user import User
|
||||
from app.utils.test_utils import setup_test_db, teardown_test_db
|
||||
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -30,6 +30,15 @@ def db_session():
|
||||
teardown_test_db(test_engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function") # Define a fixture
|
||||
async def async_test_db():
|
||||
"""Fixture provides new testing engine and session for each test run to improve isolation."""
|
||||
|
||||
test_engine, AsyncTestingSessionLocal = await setup_async_test_db()
|
||||
yield test_engine, AsyncTestingSessionLocal
|
||||
await teardown_async_test_db(test_engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(db_session):
|
||||
"""Fixture to create and return a mock User instance."""
|
||||
@@ -72,7 +81,6 @@ def event_fixture(db_session, mock_user):
|
||||
return event
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gift_item_fixture(db_session, mock_user):
|
||||
"""
|
||||
|
||||
0
backend/tests/core/__init__.py
Normal file
0
backend/tests/core/__init__.py
Normal file
260
backend/tests/core/test_auth.py
Normal file
260
backend/tests/core/test_auth.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# tests/core/test_auth.py
|
||||
import uuid
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from jose import jwt
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
get_token_data,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
TokenMissingClaimError
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestPasswordHandling:
|
||||
"""Tests for password hashing and verification functions"""
|
||||
|
||||
def test_password_hash_different_from_password(self):
|
||||
"""Test that a password hash is different from the original password"""
|
||||
password = "TestPassword123"
|
||||
hashed = get_password_hash(password)
|
||||
assert hashed != password
|
||||
|
||||
def test_verify_correct_password(self):
|
||||
"""Test that verify_password returns True for the correct password"""
|
||||
password = "TestPassword123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_incorrect_password(self):
|
||||
"""Test that verify_password returns False for an incorrect password"""
|
||||
password = "TestPassword123"
|
||||
wrong_password = "WrongPassword123"
|
||||
hashed = get_password_hash(password)
|
||||
assert verify_password(wrong_password, hashed) is False
|
||||
|
||||
def test_same_password_different_hash(self):
|
||||
"""Test that the same password gets a different hash each time"""
|
||||
password = "TestPassword123"
|
||||
hash1 = get_password_hash(password)
|
||||
hash2 = get_password_hash(password)
|
||||
assert hash1 != hash2
|
||||
|
||||
|
||||
class TestTokenCreation:
|
||||
"""Tests for token creation functions"""
|
||||
|
||||
def test_create_access_token(self):
|
||||
"""Test that an access token is created with the correct claims"""
|
||||
user_id = str(uuid.uuid4())
|
||||
custom_claims = {
|
||||
"email": "test@example.com",
|
||||
"first_name": "Test",
|
||||
"is_superuser": True
|
||||
}
|
||||
token = create_access_token(subject=user_id, claims=custom_claims)
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
assert payload["sub"] == user_id
|
||||
assert "jti" in payload
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
assert payload["type"] == "access"
|
||||
|
||||
# Check custom claims
|
||||
for key, value in custom_claims.items():
|
||||
assert payload[key] == value
|
||||
|
||||
def test_create_refresh_token(self):
|
||||
"""Test that a refresh token is created with the correct claims"""
|
||||
user_id = str(uuid.uuid4())
|
||||
token = create_refresh_token(subject=user_id)
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
assert payload["sub"] == user_id
|
||||
assert "jti" in payload
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_token_expiration(self):
|
||||
"""Test that tokens have the correct expiration time"""
|
||||
user_id = str(uuid.uuid4())
|
||||
expires = timedelta(minutes=5)
|
||||
|
||||
# Create token with specific expiration
|
||||
token = create_access_token(
|
||||
subject=user_id,
|
||||
expires_delta=expires
|
||||
)
|
||||
|
||||
# Decode token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Get actual expiration time from token
|
||||
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
|
||||
# Calculate expected expiration (approximately)
|
||||
now = datetime.now(timezone.utc)
|
||||
expected_expiration = now + expires
|
||||
|
||||
# Difference should be small (less than 1 second)
|
||||
difference = abs((expiration - expected_expiration).total_seconds())
|
||||
assert difference < 1
|
||||
|
||||
|
||||
class TestTokenDecoding:
|
||||
"""Tests for token decoding and validation functions"""
|
||||
|
||||
def test_decode_valid_token(self):
|
||||
"""Test that a valid token can be decoded"""
|
||||
user_id = str(uuid.uuid4())
|
||||
token = create_access_token(subject=user_id)
|
||||
|
||||
# Decode token
|
||||
payload = decode_token(token)
|
||||
|
||||
# Check that the subject matches
|
||||
assert payload.sub == user_id
|
||||
|
||||
def test_decode_expired_token(self):
|
||||
"""Test that an expired token raises TokenExpiredError"""
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# Create a token that's already expired by directly manipulating the payload
|
||||
now = datetime.now(timezone.utc)
|
||||
expired_time = now - timedelta(hours=1) # 1 hour in the past
|
||||
|
||||
# Create the expired token manually
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"exp": int(expired_time.timestamp()), # Set expiration in the past
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Attempting to decode should raise TokenExpiredError
|
||||
with pytest.raises(TokenExpiredError):
|
||||
decode_token(expired_token)
|
||||
|
||||
def test_decode_invalid_token(self):
|
||||
"""Test that an invalid token raises TokenInvalidError"""
|
||||
invalid_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature"
|
||||
|
||||
with pytest.raises(TokenInvalidError):
|
||||
decode_token(invalid_token)
|
||||
|
||||
def test_decode_token_with_missing_sub(self):
|
||||
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
|
||||
# Create a token without a subject
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
# No 'sub' claim
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
decode_token(token)
|
||||
|
||||
def test_decode_token_with_wrong_type(self):
|
||||
"""Test that verifying a token with wrong type raises TokenInvalidError"""
|
||||
user_id = str(uuid.uuid4())
|
||||
token = create_access_token(subject=user_id)
|
||||
|
||||
# Try to verify it as a refresh token
|
||||
with pytest.raises(TokenInvalidError):
|
||||
decode_token(token, verify_type="refresh")
|
||||
|
||||
def test_decode_with_invalid_payload(self):
|
||||
"""Test that a token with invalid payload structure raises TokenInvalidError"""
|
||||
# Create a token with an invalid payload structure - missing 'sub' which is required
|
||||
# but including 'exp' to avoid the expiration check
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
# Missing "sub" field which is required
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"invalid_field": "test"
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Should raise TokenMissingClaimError due to missing 'sub'
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
decode_token(token)
|
||||
|
||||
# Create another token with invalid type for required field
|
||||
payload = {
|
||||
"sub": 123, # sub should be a string, not an integer
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Should raise TokenInvalidError due to ValidationError
|
||||
with pytest.raises(TokenInvalidError):
|
||||
decode_token(token)
|
||||
|
||||
def test_get_token_data(self):
|
||||
"""Test extracting TokenData from a token"""
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(
|
||||
subject=str(user_id),
|
||||
claims={"is_superuser": True}
|
||||
)
|
||||
|
||||
token_data = get_token_data(token)
|
||||
|
||||
assert token_data.user_id == user_id
|
||||
assert token_data.is_superuser is True
|
||||
0
backend/tests/services/__init__.py
Normal file
0
backend/tests/services/__init__.py
Normal file
Reference in New Issue
Block a user