Add comprehensive test coverage for email service, password reset endpoints, and soft delete functionality

- Introduced unit tests for `EmailService` covering `ConsoleEmailBackend` and `SMTPEmailBackend`.
- Added tests for password reset request and confirmation endpoints, including edge cases and error handling.
- Implemented soft delete CRUD tests to validate `deleted_at` behavior and data exclusion in queries.
- Enhanced API tests for email functionality and user management workflows.
This commit is contained in:
Felipe Cardoso
2025-10-30 17:18:25 +01:00
parent 182b12b2d5
commit defa33975f
7 changed files with 2063 additions and 2 deletions

View File

@@ -0,0 +1,348 @@
# tests/api/test_auth_endpoints.py
"""
Tests for authentication endpoints.
"""
import pytest
from unittest.mock import patch, MagicMock
from fastapi import status
from app.models.user import User
from app.schemas.users import UserCreate
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.auth.limiter.enabled', False):
yield
class TestRegisterEndpoint:
"""Tests for POST /auth/register endpoint."""
def test_register_success(self, client, test_db):
"""Test successful user registration."""
response = client.post(
"/api/v1/auth/register",
json={
"email": "newuser@example.com",
"password": "SecurePassword123",
"first_name": "New",
"last_name": "User"
}
)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["email"] == "newuser@example.com"
assert data["first_name"] == "New"
assert "password" not in data
def test_register_duplicate_email(self, client, test_user):
"""Test registering with existing email."""
response = client.post(
"/api/v1/auth/register",
json={
"email": test_user.email,
"password": "SecurePassword123",
"first_name": "Duplicate",
"last_name": "User"
}
)
assert response.status_code == status.HTTP_409_CONFLICT
data = response.json()
assert data["success"] is False
def test_register_weak_password(self, client):
"""Test registration with weak password."""
response = client.post(
"/api/v1/auth/register",
json={
"email": "weakpass@example.com",
"password": "weak",
"first_name": "Weak",
"last_name": "Pass"
}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_register_unexpected_error(self, client, test_db):
"""Test registration with unexpected error."""
with patch('app.services.auth_service.AuthService.create_user') as mock_create:
mock_create.side_effect = Exception("Unexpected error")
response = client.post(
"/api/v1/auth/register",
json={
"email": "error@example.com",
"password": "SecurePassword123",
"first_name": "Error",
"last_name": "User"
}
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
class TestLoginEndpoint:
"""Tests for POST /auth/login endpoint."""
def test_login_success(self, client, test_user):
"""Test successful login."""
response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "TestPassword123"
}
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
def test_login_wrong_password(self, client, test_user):
"""Test login with wrong password."""
response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "WrongPassword123"
}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_login_nonexistent_user(self, client):
"""Test login with non-existent email."""
response = client.post(
"/api/v1/auth/login",
json={
"email": "nonexistent@example.com",
"password": "Password123"
}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_login_inactive_user(self, client, test_user, test_db):
"""Test login with inactive user."""
test_user.is_active = False
test_db.add(test_user)
test_db.commit()
response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "TestPassword123"
}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_login_unexpected_error(self, client, test_user):
"""Test login with unexpected error."""
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
mock_auth.side_effect = Exception("Database error")
response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "TestPassword123"
}
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
class TestOAuthLoginEndpoint:
"""Tests for POST /auth/login/oauth endpoint."""
def test_oauth_login_success(self, client, test_user):
"""Test successful OAuth login."""
response = client.post(
"/api/v1/auth/login/oauth",
data={
"username": test_user.email,
"password": "TestPassword123"
}
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
def test_oauth_login_wrong_credentials(self, client, test_user):
"""Test OAuth login with wrong credentials."""
response = client.post(
"/api/v1/auth/login/oauth",
data={
"username": test_user.email,
"password": "WrongPassword"
}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_oauth_login_inactive_user(self, client, test_user, test_db):
"""Test OAuth login with inactive user."""
test_user.is_active = False
test_db.add(test_user)
test_db.commit()
response = client.post(
"/api/v1/auth/login/oauth",
data={
"username": test_user.email,
"password": "TestPassword123"
}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_oauth_login_unexpected_error(self, client, test_user):
"""Test OAuth login with unexpected error."""
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
mock_auth.side_effect = Exception("Unexpected error")
response = client.post(
"/api/v1/auth/login/oauth",
data={
"username": test_user.email,
"password": "TestPassword123"
}
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
class TestRefreshTokenEndpoint:
"""Tests for POST /auth/refresh endpoint."""
def test_refresh_token_success(self, client, test_user):
"""Test successful token refresh."""
# First, login to get a refresh token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "TestPassword123"
}
)
refresh_token = login_response.json()["refresh_token"]
# Now refresh the token
response = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
def test_refresh_token_expired(self, client):
"""Test refresh with expired token."""
from app.core.auth import TokenExpiredError
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
mock_refresh.side_effect = TokenExpiredError("Token expired")
response = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "some_token"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_refresh_token_invalid(self, client):
"""Test refresh with invalid token."""
response = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid_token"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_refresh_token_unexpected_error(self, client, test_user):
"""Test refresh with unexpected error."""
# Get a valid refresh token first
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "TestPassword123"
}
)
refresh_token = login_response.json()["refresh_token"]
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
mock_refresh.side_effect = Exception("Unexpected error")
response = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
class TestGetCurrentUserEndpoint:
"""Tests for GET /auth/me endpoint."""
def test_get_current_user_success(self, client, test_user):
"""Test getting current user info."""
# First, login to get an access token
login_response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "TestPassword123"
}
)
access_token = login_response.json()["access_token"]
# Get current user info
response = client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["email"] == test_user.email
assert data["first_name"] == test_user.first_name
def test_get_current_user_no_token(self, client):
"""Test getting current user without token."""
response = client.get("/api/v1/auth/me")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_current_user_invalid_token(self, client):
"""Test getting current user with invalid token."""
response = client.get(
"/api/v1/auth/me",
headers={"Authorization": "Bearer invalid_token"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_current_user_expired_token(self, client):
"""Test getting current user with expired token."""
# Use a clearly invalid/malformed token
response = client.get(
"/api/v1/auth/me",
headers={"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.invalid"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -0,0 +1,377 @@
# tests/api/test_auth_password_reset.py
"""
Tests for password reset endpoints.
"""
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from fastapi import status
from app.schemas.users import PasswordResetRequest, PasswordResetConfirm
from app.utils.security import create_password_reset_token
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.auth.limiter.enabled', False):
yield
class TestPasswordResetRequest:
"""Tests for POST /auth/password-reset/request endpoint."""
@pytest.mark.asyncio
async def test_password_reset_request_valid_email(self, client, test_user):
"""Test password reset request with valid email."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
mock_send.return_value = True
response = client.post(
"/api/v1/auth/password-reset/request",
json={"email": test_user.email}
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["success"] is True
assert "reset link" in data["message"].lower()
# Verify email was sent
mock_send.assert_called_once()
call_args = mock_send.call_args
assert call_args.kwargs["to_email"] == test_user.email
assert call_args.kwargs["user_name"] == test_user.first_name
assert "reset_token" in call_args.kwargs
@pytest.mark.asyncio
async def test_password_reset_request_nonexistent_email(self, client):
"""Test password reset request with non-existent email."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
response = client.post(
"/api/v1/auth/password-reset/request",
json={"email": "nonexistent@example.com"}
)
# Should still return success to prevent email enumeration
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["success"] is True
# Email should not be sent
mock_send.assert_not_called()
@pytest.mark.asyncio
async def test_password_reset_request_inactive_user(self, client, test_db, test_user):
"""Test password reset request with inactive user."""
# Deactivate user
test_user.is_active = False
test_db.add(test_user)
test_db.commit()
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
response = client.post(
"/api/v1/auth/password-reset/request",
json={"email": test_user.email}
)
# Should still return success to prevent email enumeration
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["success"] is True
# Email should not be sent to inactive user
mock_send.assert_not_called()
@pytest.mark.asyncio
async def test_password_reset_request_invalid_email_format(self, client):
"""Test password reset request with invalid email format."""
response = client.post(
"/api/v1/auth/password-reset/request",
json={"email": "not-an-email"}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_password_reset_request_missing_email(self, client):
"""Test password reset request without email."""
response = client.post(
"/api/v1/auth/password-reset/request",
json={}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_password_reset_request_email_service_error(self, client, test_user):
"""Test password reset when email service fails."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
mock_send.side_effect = Exception("SMTP Error")
response = client.post(
"/api/v1/auth/password-reset/request",
json={"email": test_user.email}
)
# Should still return success even if email fails
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["success"] is True
@pytest.mark.asyncio
async def test_password_reset_request_rate_limiting(self, client, test_user):
"""Test that password reset requests are rate limited."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
mock_send.return_value = True
# Make multiple requests quickly (3/minute limit)
for _ in range(3):
response = client.post(
"/api/v1/auth/password-reset/request",
json={"email": test_user.email}
)
assert response.status_code == status.HTTP_200_OK
class TestPasswordResetConfirm:
"""Tests for POST /auth/password-reset/confirm endpoint."""
def test_password_reset_confirm_valid_token(self, client, test_user, test_db):
"""Test password reset confirmation with valid token."""
# Generate valid token
token = create_password_reset_token(test_user.email)
new_password = "NewSecure123"
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": new_password
}
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["success"] is True
assert "successfully" in data["message"].lower()
# Verify user can login with new password
test_db.refresh(test_user)
from app.core.auth import verify_password
assert verify_password(new_password, test_user.password_hash) is True
def test_password_reset_confirm_expired_token(self, client, test_user):
"""Test password reset confirmation with expired token."""
import time as time_module
# Create token that expires immediately
token = create_password_reset_token(test_user.email, expires_in=1)
# Wait for token to expire
time_module.sleep(2)
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123"
}
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = response.json()
# Check custom error format
assert data["success"] is False
error_msg = data["errors"][0]["message"].lower() if "errors" in data else ""
assert "invalid" in error_msg or "expired" in error_msg
def test_password_reset_confirm_invalid_token(self, client):
"""Test password reset confirmation with invalid token."""
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": "invalid_token_xyz",
"new_password": "NewSecure123"
}
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = response.json()
assert data["success"] is False
error_msg = data["errors"][0]["message"].lower() if "errors" in data else ""
assert "invalid" in error_msg or "expired" in error_msg
def test_password_reset_confirm_tampered_token(self, client, test_user):
"""Test password reset confirmation with tampered token."""
import base64
import json
# Create valid token and tamper with it
token = create_password_reset_token(test_user.email)
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com"
# Re-encode tampered token
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": tampered,
"new_password": "NewSecure123"
}
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_password_reset_confirm_nonexistent_user(self, client):
"""Test password reset confirmation for non-existent user."""
# Create token for email that doesn't exist
token = create_password_reset_token("nonexistent@example.com")
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123"
}
)
assert response.status_code == status.HTTP_404_NOT_FOUND
data = response.json()
assert data["success"] is False
error_msg = data["errors"][0]["message"].lower() if "errors" in data else ""
assert "not found" in error_msg
def test_password_reset_confirm_inactive_user(self, client, test_user, test_db):
"""Test password reset confirmation for inactive user."""
# Deactivate user
test_user.is_active = False
test_db.add(test_user)
test_db.commit()
token = create_password_reset_token(test_user.email)
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123"
}
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = response.json()
assert data["success"] is False
error_msg = data["errors"][0]["message"].lower() if "errors" in data else ""
assert "inactive" in error_msg
def test_password_reset_confirm_weak_password(self, client, test_user):
"""Test password reset confirmation with weak password."""
token = create_password_reset_token(test_user.email)
# Test various weak passwords
weak_passwords = [
"short1", # Too short
"NoDigitsHere", # No digits
"no_uppercase123", # No uppercase
]
for weak_password in weak_passwords:
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": weak_password
}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_password_reset_confirm_missing_fields(self, client):
"""Test password reset confirmation with missing fields."""
# Missing token
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={"new_password": "NewSecure123"}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
# Missing password
token = create_password_reset_token("test@example.com")
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={"token": token}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_password_reset_confirm_database_error(self, client, test_user, test_db):
"""Test password reset confirmation with database error."""
token = create_password_reset_token(test_user.email)
with patch.object(test_db, 'commit') as mock_commit:
mock_commit.side_effect = Exception("Database error")
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123"
}
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
data = response.json()
assert data["success"] is False
error_msg = data["errors"][0]["message"].lower() if "errors" in data else ""
assert "error" in error_msg or "resetting" in error_msg
def test_password_reset_full_flow(self, client, test_user, test_db):
"""Test complete password reset flow."""
original_password = test_user.password_hash
new_password = "BrandNew123"
# Step 1: Request password reset
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
mock_send.return_value = True
response = client.post(
"/api/v1/auth/password-reset/request",
json={"email": test_user.email}
)
assert response.status_code == status.HTTP_200_OK
# Extract token from mock call
call_args = mock_send.call_args
reset_token = call_args.kwargs["reset_token"]
# Step 2: Confirm password reset
response = client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": reset_token,
"new_password": new_password
}
)
assert response.status_code == status.HTTP_200_OK
# Step 3: Verify old password doesn't work
test_db.refresh(test_user)
from app.core.auth import verify_password
assert test_user.password_hash != original_password
# Step 4: Verify new password works
response = client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": new_password
}
)
assert response.status_code == status.HTTP_200_OK
assert "access_token" in response.json()

View File

@@ -3,8 +3,12 @@ import uuid
from datetime import datetime, timezone
import pytest
from fastapi.testclient import TestClient
from app.main import app
from app.core.database import get_db
from app.models.user import User
from app.core.auth import get_password_hash
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
@@ -64,3 +68,89 @@ def mock_user(db_session):
db_session.add(mock_user)
db_session.commit()
return mock_user
@pytest.fixture(scope="function")
def test_db():
"""
Creates a test database for integration tests.
This creates a fresh database for each test to ensure isolation.
"""
test_engine, TestingSessionLocal = setup_test_db()
# Create a session
with TestingSessionLocal() as session:
yield session
# Clean up
teardown_test_db(test_engine)
@pytest.fixture(scope="function")
def client(test_db):
"""
Create a FastAPI test client with a test database.
This overrides the get_db dependency to use the test database.
"""
def override_get_db():
try:
yield test_db
finally:
pass
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as test_client:
yield test_client
app.dependency_overrides.clear()
@pytest.fixture
def test_user(test_db):
"""
Create a test user in the database.
Password: TestPassword123
"""
user = User(
id=uuid.uuid4(),
email="testuser@example.com",
password_hash=get_password_hash("TestPassword123"),
first_name="Test",
last_name="User",
phone_number="+1234567890",
is_active=True,
is_superuser=False,
preferences=None,
)
test_db.add(user)
test_db.commit()
test_db.refresh(user)
return user
@pytest.fixture
def test_superuser(test_db):
"""
Create a test superuser in the database.
Password: SuperPassword123
"""
user = User(
id=uuid.uuid4(),
email="superuser@example.com",
password_hash=get_password_hash("SuperPassword123"),
first_name="Super",
last_name="User",
phone_number="+9876543210",
is_active=True,
is_superuser=True,
preferences=None,
)
test_db.add(user)
test_db.commit()
test_db.refresh(user)
return user

View File

@@ -0,0 +1,448 @@
# tests/crud/test_crud_base.py
"""
Tests for CRUD base operations.
"""
import pytest
from uuid import uuid4
from app.models.user import User
from app.crud.user import user as user_crud
from app.schemas.users import UserCreate, UserUpdate
class TestCRUDGet:
"""Tests for CRUD get operations."""
def test_get_by_valid_uuid(self, db_session):
"""Test getting a record by valid UUID."""
user = User(
email="get_uuid@example.com",
password_hash="hash",
first_name="Get",
last_name="UUID",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
retrieved = user_crud.get(db_session, id=user.id)
assert retrieved is not None
assert retrieved.id == user.id
assert retrieved.email == user.email
def test_get_by_string_uuid(self, db_session):
"""Test getting a record by UUID string."""
user = User(
email="get_string@example.com",
password_hash="hash",
first_name="Get",
last_name="String",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
retrieved = user_crud.get(db_session, id=str(user.id))
assert retrieved is not None
assert retrieved.id == user.id
def test_get_nonexistent(self, db_session):
"""Test getting a non-existent record."""
fake_id = uuid4()
result = user_crud.get(db_session, id=fake_id)
assert result is None
def test_get_invalid_uuid(self, db_session):
"""Test getting with invalid UUID format."""
result = user_crud.get(db_session, id="not-a-uuid")
assert result is None
class TestCRUDGetMulti:
"""Tests for get_multi operations."""
def test_get_multi_basic(self, db_session):
"""Test basic get_multi functionality."""
# Create multiple users
users = [
User(email=f"multi{i}@example.com", password_hash="hash", first_name=f"User{i}",
is_active=True, is_superuser=False)
for i in range(5)
]
db_session.add_all(users)
db_session.commit()
results = user_crud.get_multi(db_session, skip=0, limit=10)
assert len(results) >= 5
def test_get_multi_pagination(self, db_session):
"""Test pagination with get_multi."""
# Create users
users = [
User(email=f"page{i}@example.com", password_hash="hash", first_name=f"Page{i}",
is_active=True, is_superuser=False)
for i in range(10)
]
db_session.add_all(users)
db_session.commit()
# First page
page1 = user_crud.get_multi(db_session, skip=0, limit=3)
assert len(page1) == 3
# Second page
page2 = user_crud.get_multi(db_session, skip=3, limit=3)
assert len(page2) == 3
# Pages should have different users
page1_ids = {u.id for u in page1}
page2_ids = {u.id for u in page2}
assert len(page1_ids.intersection(page2_ids)) == 0
def test_get_multi_negative_skip(self, db_session):
"""Test that negative skip raises ValueError."""
with pytest.raises(ValueError, match="skip must be non-negative"):
user_crud.get_multi(db_session, skip=-1, limit=10)
def test_get_multi_negative_limit(self, db_session):
"""Test that negative limit raises ValueError."""
with pytest.raises(ValueError, match="limit must be non-negative"):
user_crud.get_multi(db_session, skip=0, limit=-1)
def test_get_multi_limit_too_large(self, db_session):
"""Test that limit over 1000 raises ValueError."""
with pytest.raises(ValueError, match="Maximum limit is 1000"):
user_crud.get_multi(db_session, skip=0, limit=1001)
class TestCRUDGetMultiWithTotal:
"""Tests for get_multi_with_total operations."""
def test_get_multi_with_total_basic(self, db_session):
"""Test basic get_multi_with_total functionality."""
# Create users
users = [
User(email=f"total{i}@example.com", password_hash="hash", first_name=f"Total{i}",
is_active=True, is_superuser=False)
for i in range(7)
]
db_session.add_all(users)
db_session.commit()
results, total = user_crud.get_multi_with_total(db_session, skip=0, limit=10)
assert total >= 7
assert len(results) >= 7
def test_get_multi_with_total_pagination(self, db_session):
"""Test pagination returns correct total."""
# Create users
users = [
User(email=f"pagetotal{i}@example.com", password_hash="hash", first_name=f"PageTotal{i}",
is_active=True, is_superuser=False)
for i in range(15)
]
db_session.add_all(users)
db_session.commit()
# First page
page1, total1 = user_crud.get_multi_with_total(db_session, skip=0, limit=5)
assert len(page1) == 5
assert total1 >= 15
# Second page should have same total
page2, total2 = user_crud.get_multi_with_total(db_session, skip=5, limit=5)
assert len(page2) == 5
assert total2 == total1
def test_get_multi_with_total_sorting_asc(self, db_session):
"""Test sorting in ascending order."""
# Create users
users = [
User(email=f"sort{i}@example.com", password_hash="hash", first_name=f"User{chr(90-i)}",
is_active=True, is_superuser=False)
for i in range(5)
]
db_session.add_all(users)
db_session.commit()
results, _ = user_crud.get_multi_with_total(
db_session,
sort_by="first_name",
sort_order="asc"
)
# Check that results are sorted
first_names = [u.first_name for u in results if u.first_name.startswith("User")]
assert first_names == sorted(first_names)
def test_get_multi_with_total_sorting_desc(self, db_session):
"""Test sorting in descending order."""
# Create users
users = [
User(email=f"desc{i}@example.com", password_hash="hash", first_name=f"User{chr(65+i)}",
is_active=True, is_superuser=False)
for i in range(5)
]
db_session.add_all(users)
db_session.commit()
results, _ = user_crud.get_multi_with_total(
db_session,
sort_by="first_name",
sort_order="desc"
)
# Check that results are sorted descending
first_names = [u.first_name for u in results if u.first_name.startswith("User")]
assert first_names == sorted(first_names, reverse=True)
def test_get_multi_with_total_filtering(self, db_session):
"""Test filtering with get_multi_with_total."""
# Create active and inactive users
active_user = User(
email="active_filter@example.com",
password_hash="hash",
first_name="Active",
is_active=True,
is_superuser=False
)
inactive_user = User(
email="inactive_filter@example.com",
password_hash="hash",
first_name="Inactive",
is_active=False,
is_superuser=False
)
db_session.add_all([active_user, inactive_user])
db_session.commit()
# Filter for active users only
results, total = user_crud.get_multi_with_total(
db_session,
filters={"is_active": True}
)
emails = [u.email for u in results]
assert "active_filter@example.com" in emails
assert "inactive_filter@example.com" not in emails
def test_get_multi_with_total_multiple_filters(self, db_session):
"""Test multiple filters."""
# Create users with different combinations
user1 = User(
email="multi1@example.com",
password_hash="hash",
first_name="User1",
is_active=True,
is_superuser=True
)
user2 = User(
email="multi2@example.com",
password_hash="hash",
first_name="User2",
is_active=True,
is_superuser=False
)
user3 = User(
email="multi3@example.com",
password_hash="hash",
first_name="User3",
is_active=False,
is_superuser=True
)
db_session.add_all([user1, user2, user3])
db_session.commit()
# Filter for active superusers
results, _ = user_crud.get_multi_with_total(
db_session,
filters={"is_active": True, "is_superuser": True}
)
emails = [u.email for u in results]
assert "multi1@example.com" in emails
assert "multi2@example.com" not in emails
assert "multi3@example.com" not in emails
def test_get_multi_with_total_nonexistent_sort_field(self, db_session):
"""Test sorting by non-existent field is ignored."""
results, _ = user_crud.get_multi_with_total(
db_session,
sort_by="nonexistent_field",
sort_order="asc"
)
# Should not raise an error, just ignore the invalid sort field
assert results is not None
def test_get_multi_with_total_nonexistent_filter_field(self, db_session):
"""Test filtering by non-existent field is ignored."""
results, _ = user_crud.get_multi_with_total(
db_session,
filters={"nonexistent_field": "value"}
)
# Should not raise an error, just ignore the invalid filter
assert results is not None
def test_get_multi_with_total_none_filter_values(self, db_session):
"""Test that None filter values are ignored."""
user = User(
email="none_filter@example.com",
password_hash="hash",
first_name="None",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
# Pass None as a filter value - should be ignored
results, _ = user_crud.get_multi_with_total(
db_session,
filters={"is_active": None}
)
# Should return all users (not filtered)
assert len(results) >= 1
class TestCRUDCreate:
"""Tests for create operations."""
def test_create_basic(self, db_session):
"""Test basic record creation."""
user_data = UserCreate(
email="create@example.com",
password="Password123",
first_name="Create",
last_name="Test"
)
created = user_crud.create(db_session, obj_in=user_data)
assert created.id is not None
assert created.email == "create@example.com"
assert created.first_name == "Create"
def test_create_duplicate_email(self, db_session):
"""Test that creating duplicate email raises error."""
user_data = UserCreate(
email="duplicate@example.com",
password="Password123",
first_name="First"
)
# Create first user
user_crud.create(db_session, obj_in=user_data)
# Try to create duplicate
with pytest.raises(ValueError, match="already exists"):
user_crud.create(db_session, obj_in=user_data)
class TestCRUDUpdate:
"""Tests for update operations."""
def test_update_basic(self, db_session):
"""Test basic record update."""
user = User(
email="update@example.com",
password_hash="hash",
first_name="Original",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
update_data = UserUpdate(first_name="Updated")
updated = user_crud.update(db_session, db_obj=user, obj_in=update_data)
assert updated.first_name == "Updated"
assert updated.email == "update@example.com" # Unchanged
def test_update_with_dict(self, db_session):
"""Test updating with dictionary."""
user = User(
email="updatedict@example.com",
password_hash="hash",
first_name="Original",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
update_data = {"first_name": "DictUpdated", "last_name": "DictLast"}
updated = user_crud.update(db_session, db_obj=user, obj_in=update_data)
assert updated.first_name == "DictUpdated"
assert updated.last_name == "DictLast"
def test_update_partial(self, db_session):
"""Test partial update (only some fields)."""
user = User(
email="partial@example.com",
password_hash="hash",
first_name="First",
last_name="Last",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
# Only update last_name
update_data = UserUpdate(last_name="NewLast")
updated = user_crud.update(db_session, db_obj=user, obj_in=update_data)
assert updated.first_name == "First" # Unchanged
assert updated.last_name == "NewLast" # Changed
class TestCRUDRemove:
"""Tests for remove (hard delete) operations."""
def test_remove_basic(self, db_session):
"""Test basic record removal."""
user = User(
email="remove@example.com",
password_hash="hash",
first_name="Remove",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
user_id = user.id
# Remove the user
removed = user_crud.remove(db_session, id=user_id)
assert removed is not None
assert removed.id == user_id
# User should no longer exist
retrieved = user_crud.get(db_session, id=user_id)
assert retrieved is None
def test_remove_nonexistent(self, db_session):
"""Test removing non-existent record."""
fake_id = uuid4()
result = user_crud.remove(db_session, id=fake_id)
assert result is None
def test_remove_invalid_uuid(self, db_session):
"""Test removing with invalid UUID."""
result = user_crud.remove(db_session, id="not-a-uuid")
assert result is None

View File

@@ -0,0 +1,324 @@
# tests/crud/test_soft_delete.py
"""
Tests for soft delete functionality in CRUD operations.
"""
import pytest
from datetime import datetime, timezone
from app.models.user import User
from app.crud.user import user as user_crud
class TestSoftDelete:
"""Tests for soft delete functionality."""
def test_soft_delete_marks_deleted_at(self, db_session):
"""Test that soft delete sets deleted_at timestamp."""
# Create a user
test_user = User(
email="softdelete@example.com",
password_hash="hashedpassword",
first_name="Soft",
last_name="Delete",
is_active=True,
is_superuser=False
)
db_session.add(test_user)
db_session.commit()
db_session.refresh(test_user)
user_id = test_user.id
assert test_user.deleted_at is None
# Soft delete the user
deleted_user = user_crud.soft_delete(db_session, id=user_id)
assert deleted_user is not None
assert deleted_user.deleted_at is not None
assert isinstance(deleted_user.deleted_at, datetime)
def test_soft_delete_excludes_from_get_multi(self, db_session):
"""Test that soft deleted records are excluded from get_multi."""
# Create two users
user1 = User(
email="user1@example.com",
password_hash="hash1",
first_name="User",
last_name="One",
is_active=True,
is_superuser=False
)
user2 = User(
email="user2@example.com",
password_hash="hash2",
first_name="User",
last_name="Two",
is_active=True,
is_superuser=False
)
db_session.add_all([user1, user2])
db_session.commit()
db_session.refresh(user1)
db_session.refresh(user2)
# Both users should be returned
users, total = user_crud.get_multi_with_total(db_session)
assert total >= 2
user_emails = [u.email for u in users]
assert "user1@example.com" in user_emails
assert "user2@example.com" in user_emails
# Soft delete user1
user_crud.soft_delete(db_session, id=user1.id)
# Only user2 should be returned
users, total = user_crud.get_multi_with_total(db_session)
user_emails = [u.email for u in users]
assert "user1@example.com" not in user_emails
assert "user2@example.com" in user_emails
def test_soft_delete_still_retrievable_by_get(self, db_session):
"""Test that soft deleted records can still be retrieved by get() method."""
# Create a user
user = User(
email="gettest@example.com",
password_hash="hash",
first_name="Get",
last_name="Test",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
user_id = user.id
# User should be retrievable
retrieved = user_crud.get(db_session, id=user_id)
assert retrieved is not None
assert retrieved.email == "gettest@example.com"
assert retrieved.deleted_at is None
# Soft delete the user
user_crud.soft_delete(db_session, id=user_id)
# User should still be retrievable by ID (soft delete doesn't prevent direct access)
retrieved = user_crud.get(db_session, id=user_id)
assert retrieved is not None
assert retrieved.deleted_at is not None
def test_soft_delete_nonexistent_record(self, db_session):
"""Test soft deleting a record that doesn't exist."""
import uuid
fake_id = uuid.uuid4()
result = user_crud.soft_delete(db_session, id=fake_id)
assert result is None
def test_restore_sets_deleted_at_to_none(self, db_session):
"""Test that restore clears the deleted_at timestamp."""
# Create and soft delete a user
user = User(
email="restore@example.com",
password_hash="hash",
first_name="Restore",
last_name="Test",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
user_id = user.id
# Soft delete
user_crud.soft_delete(db_session, id=user_id)
db_session.refresh(user)
assert user.deleted_at is not None
# Restore
restored_user = user_crud.restore(db_session, id=user_id)
assert restored_user is not None
assert restored_user.deleted_at is None
def test_restore_makes_record_available(self, db_session):
"""Test that restored records appear in queries."""
# Create and soft delete a user
user = User(
email="available@example.com",
password_hash="hash",
first_name="Available",
last_name="Test",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
user_id = user.id
user_email = user.email
# Soft delete
user_crud.soft_delete(db_session, id=user_id)
# User should not be in query results
users, _ = user_crud.get_multi_with_total(db_session)
emails = [u.email for u in users]
assert user_email not in emails
# Restore
user_crud.restore(db_session, id=user_id)
# User should now be in query results
users, _ = user_crud.get_multi_with_total(db_session)
emails = [u.email for u in users]
assert user_email in emails
def test_restore_nonexistent_record(self, db_session):
"""Test restoring a record that doesn't exist."""
import uuid
fake_id = uuid.uuid4()
result = user_crud.restore(db_session, id=fake_id)
assert result is None
def test_restore_already_active_record(self, db_session):
"""Test restoring a record that was never deleted returns None."""
# Create a user (not deleted)
user = User(
email="never_deleted@example.com",
password_hash="hash",
first_name="Never",
last_name="Deleted",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
user_id = user.id
assert user.deleted_at is None
# Restore should return None (record is not soft-deleted)
restored = user_crud.restore(db_session, id=user_id)
assert restored is None
def test_soft_delete_multiple_times(self, db_session):
"""Test soft deleting the same record multiple times."""
# Create a user
user = User(
email="multiple_delete@example.com",
password_hash="hash",
first_name="Multiple",
last_name="Delete",
is_active=True,
is_superuser=False
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
user_id = user.id
# First soft delete
first_deleted = user_crud.soft_delete(db_session, id=user_id)
assert first_deleted is not None
first_timestamp = first_deleted.deleted_at
# Restore
user_crud.restore(db_session, id=user_id)
# Second soft delete
second_deleted = user_crud.soft_delete(db_session, id=user_id)
assert second_deleted is not None
second_timestamp = second_deleted.deleted_at
# Timestamps should be different
assert second_timestamp != first_timestamp
assert second_timestamp > first_timestamp
def test_get_multi_with_filters_excludes_deleted(self, db_session):
"""Test that get_multi_with_total with filters excludes deleted records."""
# Create active and inactive users
active_user = User(
email="active_not_deleted@example.com",
password_hash="hash",
first_name="Active",
last_name="NotDeleted",
is_active=True,
is_superuser=False
)
inactive_user = User(
email="inactive_not_deleted@example.com",
password_hash="hash",
first_name="Inactive",
last_name="NotDeleted",
is_active=False,
is_superuser=False
)
deleted_active_user = User(
email="active_deleted@example.com",
password_hash="hash",
first_name="Active",
last_name="Deleted",
is_active=True,
is_superuser=False
)
db_session.add_all([active_user, inactive_user, deleted_active_user])
db_session.commit()
db_session.refresh(deleted_active_user)
# Soft delete one active user
user_crud.soft_delete(db_session, id=deleted_active_user.id)
# Filter for active users - should only return non-deleted active user
users, total = user_crud.get_multi_with_total(
db_session,
filters={"is_active": True}
)
emails = [u.email for u in users]
assert "active_not_deleted@example.com" in emails
assert "active_deleted@example.com" not in emails
assert "inactive_not_deleted@example.com" not in emails
def test_soft_delete_preserves_other_fields(self, db_session):
"""Test that soft delete doesn't modify other fields."""
# Create a user with specific data
user = User(
email="preserve@example.com",
password_hash="original_hash",
first_name="Preserve",
last_name="Fields",
phone_number="+1234567890",
is_active=True,
is_superuser=False,
preferences={"theme": "dark"}
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
user_id = user.id
original_email = user.email
original_hash = user.password_hash
original_first_name = user.first_name
original_phone = user.phone_number
original_preferences = user.preferences
# Soft delete
deleted = user_crud.soft_delete(db_session, id=user_id)
# All other fields should remain unchanged
assert deleted.email == original_email
assert deleted.password_hash == original_hash
assert deleted.first_name == original_first_name
assert deleted.phone_number == original_phone
assert deleted.preferences == original_preferences
assert deleted.is_active is True # is_active unchanged

View File

@@ -0,0 +1,281 @@
# tests/services/test_email_service.py
"""
Tests for email service functionality.
"""
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from app.services.email_service import (
EmailService,
ConsoleEmailBackend,
SMTPEmailBackend
)
class TestConsoleEmailBackend:
"""Tests for ConsoleEmailBackend."""
@pytest.mark.asyncio
async def test_send_email_basic(self):
"""Test basic email sending with console backend."""
backend = ConsoleEmailBackend()
result = await backend.send_email(
to=["user@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>",
text_content="Test Text"
)
assert result is True
@pytest.mark.asyncio
async def test_send_email_without_text_content(self):
"""Test sending email without plain text version."""
backend = ConsoleEmailBackend()
result = await backend.send_email(
to=["user@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>"
)
assert result is True
@pytest.mark.asyncio
async def test_send_email_multiple_recipients(self):
"""Test sending email to multiple recipients."""
backend = ConsoleEmailBackend()
result = await backend.send_email(
to=["user1@example.com", "user2@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>"
)
assert result is True
class TestSMTPEmailBackend:
"""Tests for SMTPEmailBackend."""
@pytest.mark.asyncio
async def test_smtp_backend_initialization(self):
"""Test SMTP backend initialization."""
backend = SMTPEmailBackend(
host="smtp.example.com",
port=587,
username="test@example.com",
password="password"
)
assert backend.host == "smtp.example.com"
assert backend.port == 587
assert backend.username == "test@example.com"
assert backend.password == "password"
@pytest.mark.asyncio
async def test_smtp_backend_fallback_to_console(self):
"""Test that SMTP backend falls back to console when not implemented."""
backend = SMTPEmailBackend(
host="smtp.example.com",
port=587,
username="test@example.com",
password="password"
)
# Should fall back to console backend since SMTP is not implemented
result = await backend.send_email(
to=["user@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>"
)
assert result is True
class TestEmailService:
"""Tests for EmailService."""
def test_email_service_default_backend(self):
"""Test that EmailService uses ConsoleEmailBackend by default."""
service = EmailService()
assert isinstance(service.backend, ConsoleEmailBackend)
def test_email_service_custom_backend(self):
"""Test EmailService with custom backend."""
custom_backend = ConsoleEmailBackend()
service = EmailService(backend=custom_backend)
assert service.backend is custom_backend
@pytest.mark.asyncio
async def test_send_password_reset_email(self):
"""Test sending password reset email."""
service = EmailService()
result = await service.send_password_reset_email(
to_email="user@example.com",
reset_token="test_token_123",
user_name="John"
)
assert result is True
@pytest.mark.asyncio
async def test_send_password_reset_email_without_name(self):
"""Test sending password reset email without user name."""
service = EmailService()
result = await service.send_password_reset_email(
to_email="user@example.com",
reset_token="test_token_123"
)
assert result is True
@pytest.mark.asyncio
async def test_send_password_reset_email_includes_token_in_url(self):
"""Test that password reset email includes token in URL."""
backend_mock = AsyncMock(spec=ConsoleEmailBackend)
backend_mock.send_email = AsyncMock(return_value=True)
service = EmailService(backend=backend_mock)
token = "test_reset_token_xyz"
await service.send_password_reset_email(
to_email="user@example.com",
reset_token=token
)
# Verify send_email was called
backend_mock.send_email.assert_called_once()
call_args = backend_mock.send_email.call_args
# Check that token is in the HTML content
html_content = call_args.kwargs['html_content']
assert token in html_content
@pytest.mark.asyncio
async def test_send_password_reset_email_error_handling(self):
"""Test error handling in password reset email."""
backend_mock = AsyncMock(spec=ConsoleEmailBackend)
backend_mock.send_email = AsyncMock(side_effect=Exception("SMTP Error"))
service = EmailService(backend=backend_mock)
result = await service.send_password_reset_email(
to_email="user@example.com",
reset_token="test_token"
)
assert result is False
@pytest.mark.asyncio
async def test_send_email_verification(self):
"""Test sending email verification email."""
service = EmailService()
result = await service.send_email_verification(
to_email="user@example.com",
verification_token="verification_token_123",
user_name="Jane"
)
assert result is True
@pytest.mark.asyncio
async def test_send_email_verification_without_name(self):
"""Test sending email verification without user name."""
service = EmailService()
result = await service.send_email_verification(
to_email="user@example.com",
verification_token="verification_token_123"
)
assert result is True
@pytest.mark.asyncio
async def test_send_email_verification_includes_token(self):
"""Test that email verification includes token in URL."""
backend_mock = AsyncMock(spec=ConsoleEmailBackend)
backend_mock.send_email = AsyncMock(return_value=True)
service = EmailService(backend=backend_mock)
token = "test_verification_token_xyz"
await service.send_email_verification(
to_email="user@example.com",
verification_token=token
)
# Verify send_email was called
backend_mock.send_email.assert_called_once()
call_args = backend_mock.send_email.call_args
# Check that token is in the HTML content
html_content = call_args.kwargs['html_content']
assert token in html_content
@pytest.mark.asyncio
async def test_send_email_verification_error_handling(self):
"""Test error handling in email verification."""
backend_mock = AsyncMock(spec=ConsoleEmailBackend)
backend_mock.send_email = AsyncMock(side_effect=Exception("Email Error"))
service = EmailService(backend=backend_mock)
result = await service.send_email_verification(
to_email="user@example.com",
verification_token="test_token"
)
assert result is False
@pytest.mark.asyncio
async def test_password_reset_email_contains_required_elements(self):
"""Test that password reset email has all required elements."""
backend_mock = AsyncMock(spec=ConsoleEmailBackend)
backend_mock.send_email = AsyncMock(return_value=True)
service = EmailService(backend=backend_mock)
await service.send_password_reset_email(
to_email="user@example.com",
reset_token="token123",
user_name="Test User"
)
call_args = backend_mock.send_email.call_args
html_content = call_args.kwargs['html_content']
text_content = call_args.kwargs['text_content']
# Check HTML content
assert "Password Reset" in html_content
assert "token123" in html_content
assert "Test User" in html_content
# Check text content
assert "Password Reset" in text_content or "password reset" in text_content.lower()
assert "token123" in text_content
@pytest.mark.asyncio
async def test_verification_email_contains_required_elements(self):
"""Test that verification email has all required elements."""
backend_mock = AsyncMock(spec=ConsoleEmailBackend)
backend_mock.send_email = AsyncMock(return_value=True)
service = EmailService(backend=backend_mock)
await service.send_email_verification(
to_email="user@example.com",
verification_token="verify123",
user_name="Test User"
)
call_args = backend_mock.send_email.call_args
html_content = call_args.kwargs['html_content']
text_content = call_args.kwargs['text_content']
# Check HTML content
assert "Verify" in html_content
assert "verify123" in html_content
assert "Test User" in html_content
# Check text content
assert "verify" in text_content.lower()
assert "verify123" in text_content

View File

@@ -8,7 +8,14 @@ import json
import pytest
from unittest.mock import patch, MagicMock
from app.utils.security import create_upload_token, verify_upload_token
from app.utils.security import (
create_upload_token,
verify_upload_token,
create_password_reset_token,
verify_password_reset_token,
create_email_verification_token,
verify_email_verification_token
)
class TestCreateUploadToken:
@@ -231,3 +238,189 @@ class TestVerifyUploadToken:
# The signature validation is already tested by test_verify_invalid_signature
# and test_verify_tampered_payload. Testing with different SECRET_KEY
# requires complex mocking that can interfere with other tests.
class TestPasswordResetTokens:
"""Tests for password reset token functions."""
def test_create_password_reset_token(self):
"""Test creating a password reset token."""
email = "user@example.com"
token = create_password_reset_token(email)
assert token is not None
assert isinstance(token, str)
assert len(token) > 0
def test_verify_password_reset_token_valid(self):
"""Test verifying a valid password reset token."""
email = "user@example.com"
token = create_password_reset_token(email)
verified_email = verify_password_reset_token(token)
assert verified_email == email
def test_verify_password_reset_token_expired(self):
"""Test that expired password reset tokens are rejected."""
email = "user@example.com"
# Create token that expires in 1 second
with patch('app.utils.security.time') as mock_time:
mock_time.time = MagicMock(return_value=1000000)
token = create_password_reset_token(email, expires_in=1)
# Fast forward time
mock_time.time.return_value = 1000002
verified_email = verify_password_reset_token(token)
assert verified_email is None
def test_verify_password_reset_token_invalid(self):
"""Test that invalid tokens are rejected."""
assert verify_password_reset_token("invalid_token") is None
assert verify_password_reset_token("") is None
def test_verify_password_reset_token_tampered(self):
"""Test that tampered tokens are rejected."""
email = "user@example.com"
token = create_password_reset_token(email)
# Decode and tamper
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com"
# Re-encode
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
verified_email = verify_password_reset_token(tampered)
assert verified_email is None
def test_verify_password_reset_token_wrong_purpose(self):
"""Test that email verification tokens can't be used for password reset."""
email = "user@example.com"
# Create an email verification token
token = create_email_verification_token(email)
# Try to verify as password reset token
verified_email = verify_password_reset_token(token)
assert verified_email is None
def test_password_reset_token_custom_expiration(self):
"""Test password reset token with custom expiration."""
email = "user@example.com"
custom_exp = 7200 # 2 hours
with patch('app.utils.security.time') as mock_time:
current_time = 1000000
mock_time.time = MagicMock(return_value=current_time)
token = create_password_reset_token(email, expires_in=custom_exp)
# Decode to check expiration
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_data = json.loads(decoded)
assert token_data["payload"]["exp"] == current_time + custom_exp
class TestEmailVerificationTokens:
"""Tests for email verification token functions."""
def test_create_email_verification_token(self):
"""Test creating an email verification token."""
email = "user@example.com"
token = create_email_verification_token(email)
assert token is not None
assert isinstance(token, str)
assert len(token) > 0
def test_verify_email_verification_token_valid(self):
"""Test verifying a valid email verification token."""
email = "user@example.com"
token = create_email_verification_token(email)
verified_email = verify_email_verification_token(token)
assert verified_email == email
def test_verify_email_verification_token_expired(self):
"""Test that expired verification tokens are rejected."""
email = "user@example.com"
with patch('app.utils.security.time') as mock_time:
mock_time.time = MagicMock(return_value=1000000)
token = create_email_verification_token(email, expires_in=1)
# Fast forward time
mock_time.time.return_value = 1000002
verified_email = verify_email_verification_token(token)
assert verified_email is None
def test_verify_email_verification_token_invalid(self):
"""Test that invalid tokens are rejected."""
assert verify_email_verification_token("invalid_token") is None
assert verify_email_verification_token("") is None
def test_verify_email_verification_token_tampered(self):
"""Test that tampered verification tokens are rejected."""
email = "user@example.com"
token = create_email_verification_token(email)
# Decode and tamper
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com"
# Re-encode
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
verified_email = verify_email_verification_token(tampered)
assert verified_email is None
def test_verify_email_verification_token_wrong_purpose(self):
"""Test that password reset tokens can't be used for email verification."""
email = "user@example.com"
# Create a password reset token
token = create_password_reset_token(email)
# Try to verify as email verification token
verified_email = verify_email_verification_token(token)
assert verified_email is None
def test_email_verification_token_default_expiration(self):
"""Test email verification token with default 24-hour expiration."""
email = "user@example.com"
with patch('app.utils.security.time') as mock_time:
current_time = 1000000
mock_time.time = MagicMock(return_value=current_time)
token = create_email_verification_token(email)
# Decode to check expiration (should be 86400 seconds = 24 hours)
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_data = json.loads(decoded)
assert token_data["payload"]["exp"] == current_time + 86400
def test_tokens_are_unique(self):
"""Test that multiple tokens for the same email are unique."""
email = "user@example.com"
token1 = create_password_reset_token(email)
token2 = create_password_reset_token(email)
assert token1 != token2
def test_verification_and_reset_tokens_are_different(self):
"""Test that verification and reset tokens for same email are different."""
email = "user@example.com"
reset_token = create_password_reset_token(email)
verify_token = create_email_verification_token(email)
assert reset_token != verify_token