- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest).
- Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting.
- Updated `requirements.txt` to include Ruff and remove replaced tools.
- Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
235 lines
8.0 KiB
Python
Executable File
235 lines
8.0 KiB
Python
Executable File
# tests/core/test_auth.py
|
|
import uuid
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import pytest
|
|
from jose import jwt
|
|
|
|
from app.core.auth import (
|
|
TokenExpiredError,
|
|
TokenInvalidError,
|
|
TokenMissingClaimError,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
decode_token,
|
|
get_password_hash,
|
|
get_token_data,
|
|
verify_password,
|
|
)
|
|
from app.core.config import settings
|
|
|
|
|
|
class TestPasswordHandling:
    """Checks for ``get_password_hash`` and ``verify_password``."""

    def test_password_hash_different_from_password(self):
        """The hash returned for a password is not the password itself."""
        plaintext = "TestPassword123!"
        assert get_password_hash(plaintext) != plaintext

    def test_verify_correct_password(self):
        """A password verifies against the hash that was produced from it."""
        plaintext = "TestPassword123!"
        digest = get_password_hash(plaintext)
        outcome = verify_password(plaintext, digest)
        assert outcome is True

    def test_verify_incorrect_password(self):
        """A different password does not verify against the hash."""
        plaintext, impostor = "TestPassword123!", "WrongPassword123!"
        digest = get_password_hash(plaintext)
        outcome = verify_password(impostor, digest)
        assert outcome is False

    def test_same_password_different_hash(self):
        """Hashing the same password twice yields two distinct hashes."""
        plaintext = "TestPassword123!"
        first_digest = get_password_hash(plaintext)
        second_digest = get_password_hash(plaintext)
        assert first_digest != second_digest
|
|
|
|
|
|
class TestTokenCreation:
    """Tests for ``create_access_token`` and ``create_refresh_token``.

    Every test decodes the freshly minted token directly with ``jwt.decode``
    (using ``settings.SECRET_KEY`` and ``settings.ALGORITHM``) so that the
    assertions inspect the raw claims instead of going through
    ``decode_token``.
    """

    def test_create_access_token(self):
        """An access token carries the standard claims plus the custom ones."""
        user_id = str(uuid.uuid4())
        custom_claims = {
            "email": "test@example.com",
            "first_name": "Test",
            "is_superuser": True,
        }
        token = create_access_token(subject=user_id, claims=custom_claims)

        # Decode token to verify claims
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )

        # Check standard claims
        assert payload["sub"] == user_id
        for claim in ("jti", "exp", "iat"):
            assert claim in payload, f"missing standard claim: {claim}"
        assert payload["type"] == "access"

        # Check custom claims
        for key, value in custom_claims.items():
            assert payload[key] == value

    def test_create_refresh_token(self):
        """A refresh token carries the standard claims and type ``refresh``."""
        user_id = str(uuid.uuid4())
        token = create_refresh_token(subject=user_id)

        # Decode token to verify claims
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )

        # Check standard claims
        assert payload["sub"] == user_id
        for claim in ("jti", "exp", "iat"):
            assert claim in payload, f"missing standard claim: {claim}"
        assert payload["type"] == "refresh"

    def test_token_expiration(self):
        """``expires_delta`` determines the token's ``exp`` claim.

        The clock is read immediately before and after the token is created,
        and ``exp`` must fall inside that window shifted by ``expires_delta``.
        Comparing against a single clock read taken afterwards with a strict
        sub-second tolerance is flaky: if ``exp`` is stored as whole seconds
        (as the hand-built payloads elsewhere in this module do), a token
        minted late in a second would fail spuriously.
        """
        user_id = str(uuid.uuid4())
        expires = timedelta(minutes=5)
        # Slack for whole-second serialization of ``exp`` (in either
        # rounding direction).
        slack = timedelta(seconds=1)

        # Bracket token creation with two clock reads
        before = datetime.now(UTC)
        token = create_access_token(subject=user_id, expires_delta=expires)
        after = datetime.now(UTC)

        # Decode token
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )

        # Get actual expiration time from token
        expiration = datetime.fromtimestamp(payload["exp"], tz=UTC)

        earliest = before + expires - slack
        latest = after + expires + slack
        assert earliest <= expiration <= latest
|
|
|
|
|
|
class TestTokenDecoding:
    """Checks for ``decode_token`` and ``get_token_data``."""

    @staticmethod
    def _sign(claims):
        """Encode ``claims`` as a JWT with the application's key and algorithm."""
        return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

    def test_decode_valid_token(self):
        """Decoding a freshly created access token returns its subject."""
        subject = str(uuid.uuid4())
        decoded = decode_token(create_access_token(subject=subject))

        assert decoded.sub == subject

    def test_decode_expired_token(self):
        """A token whose ``exp`` lies in the past raises TokenExpiredError."""
        subject = str(uuid.uuid4())

        # Build the claims by hand so that ``exp`` can be placed one hour
        # in the past.
        issued_at = datetime.now(UTC)
        one_hour_ago = issued_at - timedelta(hours=1)
        stale_token = self._sign(
            {
                "sub": subject,
                "exp": int(one_hour_ago.timestamp()),
                "iat": int(issued_at.timestamp()),
                "jti": str(uuid.uuid4()),
                "type": "access",
            }
        )

        with pytest.raises(TokenExpiredError):
            decode_token(stale_token)

    def test_decode_invalid_token(self):
        """A token with a bogus signature segment raises TokenInvalidError."""
        bogus = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJpbnZhbGlkIn0.invalid-signature"

        with pytest.raises(TokenInvalidError):
            decode_token(bogus)

    def test_decode_token_with_missing_sub(self):
        """A token lacking the ``sub`` claim raises TokenMissingClaimError."""
        issued_at = datetime.now(UTC)
        valid_until = issued_at + timedelta(minutes=30)
        # Deliberately no "sub" entry.
        subjectless = self._sign(
            {
                "exp": int(valid_until.timestamp()),
                "iat": int(issued_at.timestamp()),
                "jti": str(uuid.uuid4()),
                "type": "access",
            }
        )

        with pytest.raises(TokenMissingClaimError):
            decode_token(subjectless)

    def test_decode_token_with_wrong_type(self):
        """Decoding an access token as a refresh token raises TokenInvalidError."""
        subject = str(uuid.uuid4())
        access_token = create_access_token(subject=subject)

        with pytest.raises(TokenInvalidError):
            decode_token(access_token, verify_type="refresh")

    def test_decode_with_invalid_payload(self):
        """Malformed payloads are rejected with the matching error type.

        Two cases are covered: a payload with no ``sub`` claim (expects
        TokenMissingClaimError) and a payload whose ``sub`` is an integer
        rather than a string (expects TokenInvalidError).
        """
        issued_at = datetime.now(UTC)
        # Unexpired ``exp`` so that expiry is not the reason for rejection.
        exp_stamp = int((issued_at + timedelta(minutes=30)).timestamp())

        # Case 1: no "sub" claim, plus an unrelated extra field.
        without_subject = self._sign(
            {
                "exp": exp_stamp,
                "iat": int(issued_at.timestamp()),
                "jti": str(uuid.uuid4()),
                "invalid_field": "test",
            }
        )
        with pytest.raises(TokenMissingClaimError):
            decode_token(without_subject)

        # Case 2: "sub" is present but is an integer instead of a string.
        # NOTE(review): presumably rejected during payload validation inside
        # decode_token -- confirm against its implementation.
        numeric_subject = self._sign({"sub": 123, "exp": exp_stamp})
        with pytest.raises(TokenInvalidError):
            decode_token(numeric_subject)

    def test_get_token_data(self):
        """``get_token_data`` exposes the user id and superuser flag of a token."""
        expected_id = uuid.uuid4()
        access_token = create_access_token(
            subject=str(expected_id), claims={"is_superuser": True}
        )

        extracted = get_token_data(access_token)

        assert extracted.user_id == expected_id
        assert extracted.is_superuser is True
|