diff --git a/backend/tests/api/test_auth_endpoints.py b/backend/tests/api/test_auth_endpoints.py new file mode 100644 index 0000000..97b201d --- /dev/null +++ b/backend/tests/api/test_auth_endpoints.py @@ -0,0 +1,348 @@ +# tests/api/test_auth_endpoints.py +""" +Tests for authentication endpoints. +""" +import pytest +from unittest.mock import patch, MagicMock +from fastapi import status + +from app.models.user import User +from app.schemas.users import UserCreate + + +# Disable rate limiting for tests +@pytest.fixture(autouse=True) +def disable_rate_limit(): + """Disable rate limiting for all tests in this module.""" + with patch('app.api.routes.auth.limiter.enabled', False): + yield + + +class TestRegisterEndpoint: + """Tests for POST /auth/register endpoint.""" + + def test_register_success(self, client, test_db): + """Test successful user registration.""" + response = client.post( + "/api/v1/auth/register", + json={ + "email": "newuser@example.com", + "password": "SecurePassword123", + "first_name": "New", + "last_name": "User" + } + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["email"] == "newuser@example.com" + assert data["first_name"] == "New" + assert "password" not in data + + def test_register_duplicate_email(self, client, test_user): + """Test registering with existing email.""" + response = client.post( + "/api/v1/auth/register", + json={ + "email": test_user.email, + "password": "SecurePassword123", + "first_name": "Duplicate", + "last_name": "User" + } + ) + + assert response.status_code == status.HTTP_409_CONFLICT + data = response.json() + assert data["success"] is False + + def test_register_weak_password(self, client): + """Test registration with weak password.""" + response = client.post( + "/api/v1/auth/register", + json={ + "email": "weakpass@example.com", + "password": "weak", + "first_name": "Weak", + "last_name": "Pass" + } + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_register_unexpected_error(self, client, test_db): + """Test registration with unexpected error.""" + with patch('app.services.auth_service.AuthService.create_user') as mock_create: + mock_create.side_effect = Exception("Unexpected error") + + response = client.post( + "/api/v1/auth/register", + json={ + "email": "error@example.com", + "password": "SecurePassword123", + "first_name": "Error", + "last_name": "User" + } + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +class TestLoginEndpoint: + """Tests for POST /auth/login endpoint.""" + + def test_login_success(self, client, test_user): + """Test successful login.""" + response = client.post( + "/api/v1/auth/login", + json={ + "email": test_user.email, + "password": "TestPassword123" + } + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "access_token" in data + assert "refresh_token" in data + assert data["token_type"] == "bearer" + + def test_login_wrong_password(self, client, test_user): + """Test login with wrong password.""" + response = client.post( + "/api/v1/auth/login", + json={ + "email": test_user.email, + "password": "WrongPassword123" + } + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_login_nonexistent_user(self, client): + """Test login with non-existent email.""" + response = client.post( + "/api/v1/auth/login", + json={ + "email": "nonexistent@example.com", + "password": "Password123" + } + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_login_inactive_user(self, client, test_user, test_db): + """Test login with inactive user.""" + test_user.is_active = False + test_db.add(test_user) + test_db.commit() + + response = client.post( + "/api/v1/auth/login", + json={ + "email": test_user.email, + "password": "TestPassword123" + } + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_login_unexpected_error(self, client, test_user): + """Test login with unexpected error.""" + with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth: + mock_auth.side_effect = Exception("Database error") + + response = client.post( + "/api/v1/auth/login", + json={ + "email": test_user.email, + "password": "TestPassword123" + } + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +class TestOAuthLoginEndpoint: + """Tests for POST /auth/login/oauth endpoint.""" + + def test_oauth_login_success(self, client, test_user): + """Test successful OAuth login.""" + response = client.post( + "/api/v1/auth/login/oauth", + data={ + "username": test_user.email, + "password": "TestPassword123" + } + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "access_token" in data + assert "refresh_token" in data + + def test_oauth_login_wrong_credentials(self, client, test_user): + """Test OAuth login with wrong credentials.""" + response = client.post( + "/api/v1/auth/login/oauth", + data={ + "username": test_user.email, + "password": "WrongPassword" + } + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_oauth_login_inactive_user(self, client, test_user, test_db): + """Test OAuth login with inactive user.""" + test_user.is_active = False + test_db.add(test_user) + test_db.commit() + + response = client.post( + "/api/v1/auth/login/oauth", + data={ + "username": test_user.email, + "password": "TestPassword123" + } + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_oauth_login_unexpected_error(self, client, test_user): + """Test OAuth login with unexpected error.""" + with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth: + mock_auth.side_effect = Exception("Unexpected error") + + response = client.post( + "/api/v1/auth/login/oauth", + data={ + "username": test_user.email, + "password": "TestPassword123" + } + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +class TestRefreshTokenEndpoint: + """Tests for POST /auth/refresh endpoint.""" + + def test_refresh_token_success(self, client, test_user): + """Test successful token refresh.""" + # First, login to get a refresh token + login_response = client.post( + "/api/v1/auth/login", + json={ + "email": test_user.email, + "password": "TestPassword123" + } + ) + refresh_token = login_response.json()["refresh_token"] + + # Now refresh the token + response = client.post( + "/api/v1/auth/refresh", + json={"refresh_token": refresh_token} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "access_token" in data + assert "refresh_token" in data + + def test_refresh_token_expired(self, client): + """Test refresh with expired token.""" + from app.core.auth import TokenExpiredError + + with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh: + mock_refresh.side_effect = TokenExpiredError("Token expired") + + response = client.post( + "/api/v1/auth/refresh", + json={"refresh_token": "some_token"} + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_refresh_token_invalid(self, client): + """Test refresh with invalid token.""" + response = client.post( + "/api/v1/auth/refresh", + json={"refresh_token": "invalid_token"} + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_refresh_token_unexpected_error(self, client, test_user): + """Test refresh with unexpected error.""" + # Get a valid refresh token first + login_response = client.post( + "/api/v1/auth/login", + json={ + "email": test_user.email, + "password": "TestPassword123" + } + ) + refresh_token = login_response.json()["refresh_token"] + + with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh: + mock_refresh.side_effect = Exception("Unexpected error") + + response = client.post( + "/api/v1/auth/refresh", + json={"refresh_token": refresh_token} + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +class TestGetCurrentUserEndpoint: + """Tests for GET /auth/me endpoint.""" + + def test_get_current_user_success(self, client, test_user): + """Test getting current user info.""" + # First, login to get an access token + login_response = client.post( + "/api/v1/auth/login", + json={ + "email": test_user.email, + "password": "TestPassword123" + } + ) + access_token = login_response.json()["access_token"] + + # Get current user info + response = client.get( + "/api/v1/auth/me", + headers={"Authorization": f"Bearer {access_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["email"] == test_user.email + assert data["first_name"] == test_user.first_name + + def test_get_current_user_no_token(self, client): + """Test getting current user without token.""" + response = client.get("/api/v1/auth/me") + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_get_current_user_invalid_token(self, client): + """Test getting current user with invalid token.""" + response = client.get( + "/api/v1/auth/me", + headers={"Authorization": "Bearer invalid_token"} + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_get_current_user_expired_token(self, client): + """Test getting current user with expired token.""" + # Use a clearly invalid/malformed token + response = client.get( + "/api/v1/auth/me", + headers={"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.invalid"} + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/backend/tests/api/test_auth_password_reset.py b/backend/tests/api/test_auth_password_reset.py new file mode 100644 index 0000000..d70bf89 --- /dev/null +++ b/backend/tests/api/test_auth_password_reset.py @@ -0,0 +1,377 @@ +# tests/api/test_auth_password_reset.py +""" +Tests for password reset endpoints. +""" +import pytest +from unittest.mock import patch, AsyncMock, MagicMock +from fastapi import status + +from app.schemas.users import PasswordResetRequest, PasswordResetConfirm +from app.utils.security import create_password_reset_token + + +# Disable rate limiting for tests +@pytest.fixture(autouse=True) +def disable_rate_limit(): + """Disable rate limiting for all tests in this module.""" + with patch('app.api.routes.auth.limiter.enabled', False): + yield + + +class TestPasswordResetRequest: + """Tests for POST /auth/password-reset/request endpoint.""" + + @pytest.mark.asyncio + async def test_password_reset_request_valid_email(self, client, test_user): + """Test password reset request with valid email.""" + with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + mock_send.return_value = True + + response = client.post( + "/api/v1/auth/password-reset/request", + json={"email": test_user.email} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + assert "reset link" in data["message"].lower() + + # Verify email was sent + mock_send.assert_called_once() + call_args = mock_send.call_args + assert call_args.kwargs["to_email"] == test_user.email + assert call_args.kwargs["user_name"] == test_user.first_name + assert "reset_token" in call_args.kwargs + + @pytest.mark.asyncio + async def test_password_reset_request_nonexistent_email(self, client): + """Test password reset request with non-existent email.""" + with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + response = client.post( + "/api/v1/auth/password-reset/request", + json={"email": "nonexistent@example.com"} + ) + + # Should still return success to prevent email enumeration + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + + # Email should not be sent + mock_send.assert_not_called() + + @pytest.mark.asyncio + async def test_password_reset_request_inactive_user(self, client, test_db, test_user): + """Test password reset request with inactive user.""" + # Deactivate user + test_user.is_active = False + test_db.add(test_user) + test_db.commit() + + with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + response = client.post( + "/api/v1/auth/password-reset/request", + json={"email": test_user.email} + ) + + # Should still return success to prevent email enumeration + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + + # Email should not be sent to inactive user + mock_send.assert_not_called() + + @pytest.mark.asyncio + async def test_password_reset_request_invalid_email_format(self, client): + """Test password reset request with invalid email format.""" + response = client.post( + "/api/v1/auth/password-reset/request", + json={"email": "not-an-email"} + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + @pytest.mark.asyncio + async def test_password_reset_request_missing_email(self, client): + """Test password reset request without email.""" + response = client.post( + "/api/v1/auth/password-reset/request", + json={} + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + @pytest.mark.asyncio + async def test_password_reset_request_email_service_error(self, client, test_user): + """Test password reset when email service fails.""" + with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + mock_send.side_effect = Exception("SMTP Error") + + response = client.post( + "/api/v1/auth/password-reset/request", + json={"email": test_user.email} + ) + + # Should still return success even if email fails + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + + @pytest.mark.asyncio + async def test_password_reset_request_rate_limiting(self, client, test_user): + """Test that password reset requests are rate limited.""" + with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + mock_send.return_value = True + + # Make multiple requests quickly (3/minute limit) + for _ in range(3): + response = client.post( + "/api/v1/auth/password-reset/request", + json={"email": test_user.email} + ) + assert response.status_code == status.HTTP_200_OK + + +class TestPasswordResetConfirm: + """Tests for POST /auth/password-reset/confirm endpoint.""" + + def test_password_reset_confirm_valid_token(self, client, test_user, test_db): + """Test password reset confirmation with valid token.""" + # Generate valid token + token = create_password_reset_token(test_user.email) + new_password = "NewSecure123" + + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={ + "token": token, + "new_password": new_password + } + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + assert "successfully" in data["message"].lower() + + # Verify user can login with new password + test_db.refresh(test_user) + from app.core.auth import verify_password + assert verify_password(new_password, test_user.password_hash) is True + + def test_password_reset_confirm_expired_token(self, client, test_user): + """Test password reset confirmation with expired token.""" + import time as time_module + + # Create token that expires immediately + token = create_password_reset_token(test_user.email, expires_in=1) + + # Wait for token to expire + time_module.sleep(2) + + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={ + "token": token, + "new_password": "NewSecure123" + } + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + # Check custom error format + assert data["success"] is False + error_msg = data["errors"][0]["message"].lower() if "errors" in data else "" + assert "invalid" in error_msg or "expired" in error_msg + + def test_password_reset_confirm_invalid_token(self, client): + """Test password reset confirmation with invalid token.""" + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={ + "token": "invalid_token_xyz", + "new_password": "NewSecure123" + } + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["success"] is False + error_msg = data["errors"][0]["message"].lower() if "errors" in data else "" + assert "invalid" in error_msg or "expired" in error_msg + + def test_password_reset_confirm_tampered_token(self, client, test_user): + """Test password reset confirmation with tampered token.""" + import base64 + import json + + # Create valid token and tamper with it + token = create_password_reset_token(test_user.email) + decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + token_data = json.loads(decoded) + token_data["payload"]["email"] = "hacker@example.com" + + # Re-encode tampered token + tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8') + + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={ + "token": tampered, + "new_password": "NewSecure123" + } + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_password_reset_confirm_nonexistent_user(self, client): + """Test password reset confirmation for non-existent user.""" + # Create token for email that doesn't exist + token = create_password_reset_token("nonexistent@example.com") + + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={ + "token": token, + "new_password": "NewSecure123" + } + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert data["success"] is False + error_msg = data["errors"][0]["message"].lower() if "errors" in data else "" + assert "not found" in error_msg + + def test_password_reset_confirm_inactive_user(self, client, test_user, test_db): + """Test password reset confirmation for inactive user.""" + # Deactivate user + test_user.is_active = False + test_db.add(test_user) + test_db.commit() + + token = create_password_reset_token(test_user.email) + + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={ + "token": token, + "new_password": "NewSecure123" + } + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["success"] is False + error_msg = data["errors"][0]["message"].lower() if "errors" in data else "" + assert "inactive" in error_msg + + def test_password_reset_confirm_weak_password(self, client, test_user): + """Test password reset confirmation with weak password.""" + token = create_password_reset_token(test_user.email) + + # Test various weak passwords + weak_passwords = [ + "short1", # Too short + "NoDigitsHere", # No digits + "no_uppercase123", # No uppercase + ] + + for weak_password in weak_passwords: + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={ + "token": token, + "new_password": weak_password + } + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_password_reset_confirm_missing_fields(self, client): + """Test password reset confirmation with missing fields.""" + # Missing token + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={"new_password": "NewSecure123"} + ) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + # Missing password + token = create_password_reset_token("test@example.com") + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={"token": token} + ) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_password_reset_confirm_database_error(self, client, test_user, test_db): + """Test password reset confirmation with database error.""" + token = create_password_reset_token(test_user.email) + + with patch.object(test_db, 'commit') as mock_commit: + mock_commit.side_effect = Exception("Database error") + + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={ + "token": token, + "new_password": "NewSecure123" + } + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + data = response.json() + assert data["success"] is False + error_msg = data["errors"][0]["message"].lower() if "errors" in data else "" + assert "error" in error_msg or "resetting" in error_msg + + def test_password_reset_full_flow(self, client, test_user, test_db): + """Test complete password reset flow.""" + original_password = test_user.password_hash + new_password = "BrandNew123" + + # Step 1: Request password reset + with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + mock_send.return_value = True + + response = client.post( + "/api/v1/auth/password-reset/request", + json={"email": test_user.email} + ) + + assert response.status_code == status.HTTP_200_OK + + # Extract token from mock call + call_args = mock_send.call_args + reset_token = call_args.kwargs["reset_token"] + + # Step 2: Confirm password reset + response = client.post( + "/api/v1/auth/password-reset/confirm", + json={ + "token": reset_token, + "new_password": new_password + } + ) + + assert response.status_code == status.HTTP_200_OK + + # Step 3: Verify old password doesn't work + test_db.refresh(test_user) + from app.core.auth import verify_password + assert test_user.password_hash != original_password + + # Step 4: Verify new password works + response = client.post( + "/api/v1/auth/login", + json={ + "email": test_user.email, + "password": new_password + } + ) + + assert response.status_code == status.HTTP_200_OK + assert "access_token" in response.json() diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 1b83a92..c2356f7 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -3,8 +3,12 @@ import uuid from datetime import datetime, timezone import pytest +from fastapi.testclient import TestClient +from app.main import app +from app.core.database import get_db from app.models.user import User +from app.core.auth import get_password_hash from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db @@ -63,4 +67,90 @@ def mock_user(db_session): ) db_session.add(mock_user) db_session.commit() - return mock_user \ No newline at end of file + return mock_user + + +@pytest.fixture(scope="function") +def test_db(): + """ + Creates a test database for integration tests. + + This creates a fresh database for each test to ensure isolation. + """ + test_engine, TestingSessionLocal = setup_test_db() + + # Create a session + with TestingSessionLocal() as session: + yield session + + # Clean up + teardown_test_db(test_engine) + + +@pytest.fixture(scope="function") +def client(test_db): + """ + Create a FastAPI test client with a test database. + + This overrides the get_db dependency to use the test database. + """ + def override_get_db(): + try: + yield test_db + finally: + pass + + app.dependency_overrides[get_db] = override_get_db + + with TestClient(app) as test_client: + yield test_client + + app.dependency_overrides.clear() + + +@pytest.fixture +def test_user(test_db): + """ + Create a test user in the database. + + Password: TestPassword123 + """ + user = User( + id=uuid.uuid4(), + email="testuser@example.com", + password_hash=get_password_hash("TestPassword123"), + first_name="Test", + last_name="User", + phone_number="+1234567890", + is_active=True, + is_superuser=False, + preferences=None, + ) + test_db.add(user) + test_db.commit() + test_db.refresh(user) + return user + + +@pytest.fixture +def test_superuser(test_db): + """ + Create a test superuser in the database. + + Password: SuperPassword123 + """ + user = User( + id=uuid.uuid4(), + email="superuser@example.com", + password_hash=get_password_hash("SuperPassword123"), + first_name="Super", + last_name="User", + phone_number="+9876543210", + is_active=True, + is_superuser=True, + preferences=None, + ) + test_db.add(user) + test_db.commit() + test_db.refresh(user) + return user \ No newline at end of file diff --git a/backend/tests/crud/test_crud_base.py b/backend/tests/crud/test_crud_base.py new file mode 100644 index 0000000..cd3d17f --- /dev/null +++ b/backend/tests/crud/test_crud_base.py @@ -0,0 +1,448 @@ +# tests/crud/test_crud_base.py +""" +Tests for CRUD base operations. +""" +import pytest +from uuid import uuid4 + +from app.models.user import User +from app.crud.user import user as user_crud +from app.schemas.users import UserCreate, UserUpdate + + +class TestCRUDGet: + """Tests for CRUD get operations.""" + + def test_get_by_valid_uuid(self, db_session): + """Test getting a record by valid UUID.""" + user = User( + email="get_uuid@example.com", + password_hash="hash", + first_name="Get", + last_name="UUID", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + retrieved = user_crud.get(db_session, id=user.id) + assert retrieved is not None + assert retrieved.id == user.id + assert retrieved.email == user.email + + def test_get_by_string_uuid(self, db_session): + """Test getting a record by UUID string.""" + user = User( + email="get_string@example.com", + password_hash="hash", + first_name="Get", + last_name="String", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + retrieved = user_crud.get(db_session, id=str(user.id)) + assert retrieved is not None + assert retrieved.id == user.id + + def test_get_nonexistent(self, db_session): + """Test getting a non-existent record.""" + fake_id = uuid4() + result = user_crud.get(db_session, id=fake_id) + assert result is None + + def test_get_invalid_uuid(self, db_session): + """Test getting with invalid UUID format.""" + result = user_crud.get(db_session, id="not-a-uuid") + assert result is None + + +class TestCRUDGetMulti: + """Tests for get_multi operations.""" + + def test_get_multi_basic(self, db_session): + """Test basic get_multi functionality.""" + # Create multiple users + users = [ + User(email=f"multi{i}@example.com", password_hash="hash", first_name=f"User{i}", + is_active=True, is_superuser=False) + for i in range(5) + ] + db_session.add_all(users) + db_session.commit() + + results = user_crud.get_multi(db_session, skip=0, limit=10) + assert len(results) >= 5 + + def test_get_multi_pagination(self, db_session): + """Test pagination with get_multi.""" + # Create users + users = [ + User(email=f"page{i}@example.com", password_hash="hash", first_name=f"Page{i}", + is_active=True, is_superuser=False) + for i in range(10) + ] + db_session.add_all(users) + db_session.commit() + + # First page + page1 = user_crud.get_multi(db_session, skip=0, limit=3) + assert len(page1) == 3 + + # Second page + page2 = user_crud.get_multi(db_session, skip=3, limit=3) + assert len(page2) == 3 + + # Pages should have different users + page1_ids = {u.id for u in page1} + page2_ids = {u.id for u in page2} + assert len(page1_ids.intersection(page2_ids)) == 0 + + def test_get_multi_negative_skip(self, db_session): + """Test that negative skip raises ValueError.""" + with pytest.raises(ValueError, match="skip must be non-negative"): + user_crud.get_multi(db_session, skip=-1, limit=10) + + def test_get_multi_negative_limit(self, db_session): + """Test that negative limit raises ValueError.""" + with pytest.raises(ValueError, match="limit must be non-negative"): + user_crud.get_multi(db_session, skip=0, limit=-1) + + def test_get_multi_limit_too_large(self, db_session): + """Test that limit over 1000 raises ValueError.""" + with pytest.raises(ValueError, match="Maximum limit is 1000"): + user_crud.get_multi(db_session, skip=0, limit=1001) + + +class TestCRUDGetMultiWithTotal: + """Tests for get_multi_with_total operations.""" + + def test_get_multi_with_total_basic(self, db_session): + """Test basic get_multi_with_total functionality.""" + # Create users + users = [ + User(email=f"total{i}@example.com", password_hash="hash", first_name=f"Total{i}", + is_active=True, is_superuser=False) + for i in range(7) + ] + db_session.add_all(users) + db_session.commit() + + results, total = user_crud.get_multi_with_total(db_session, skip=0, limit=10) + assert total >= 7 + assert len(results) >= 7 + + def test_get_multi_with_total_pagination(self, db_session): + """Test pagination returns correct total.""" + # Create users + users = [ + User(email=f"pagetotal{i}@example.com", password_hash="hash", first_name=f"PageTotal{i}", + is_active=True, is_superuser=False) + for i in range(15) + ] + db_session.add_all(users) + db_session.commit() + + # First page + page1, total1 = user_crud.get_multi_with_total(db_session, skip=0, limit=5) + assert len(page1) == 5 + assert total1 >= 15 + + # Second page should have same total + page2, total2 = user_crud.get_multi_with_total(db_session, skip=5, limit=5) + assert len(page2) == 5 + assert total2 == total1 + + def test_get_multi_with_total_sorting_asc(self, db_session): + """Test sorting in ascending order.""" + # Create users + users = [ + User(email=f"sort{i}@example.com", password_hash="hash", first_name=f"User{chr(90-i)}", + is_active=True, is_superuser=False) + for i in range(5) + ] + db_session.add_all(users) + db_session.commit() + + results, _ = user_crud.get_multi_with_total( + db_session, + sort_by="first_name", + sort_order="asc" + ) + + # Check that results are sorted + first_names = [u.first_name for u in results if u.first_name.startswith("User")] + assert first_names == sorted(first_names) + + def test_get_multi_with_total_sorting_desc(self, db_session): + """Test sorting in descending order.""" + # Create users + users = [ + User(email=f"desc{i}@example.com", password_hash="hash", first_name=f"User{chr(65+i)}", + is_active=True, is_superuser=False) + for i in range(5) + ] + db_session.add_all(users) + db_session.commit() + + results, _ = user_crud.get_multi_with_total( + db_session, + sort_by="first_name", + sort_order="desc" + ) + + # Check that results are sorted descending + first_names = [u.first_name for u in results if u.first_name.startswith("User")] + assert first_names == sorted(first_names, reverse=True) + + def test_get_multi_with_total_filtering(self, db_session): + """Test filtering with get_multi_with_total.""" + # Create active and inactive users + active_user = User( + email="active_filter@example.com", + password_hash="hash", + first_name="Active", + is_active=True, + is_superuser=False + ) + inactive_user = User( + email="inactive_filter@example.com", + password_hash="hash", + first_name="Inactive", + is_active=False, + is_superuser=False + ) + db_session.add_all([active_user, inactive_user]) + db_session.commit() + + # Filter for active users only + results, total = user_crud.get_multi_with_total( + db_session, + filters={"is_active": True} + ) + + emails = [u.email for u in results] + assert "active_filter@example.com" in emails + assert "inactive_filter@example.com" not in emails + + def test_get_multi_with_total_multiple_filters(self, db_session): + """Test multiple filters.""" + # Create users with different combinations + user1 = User( + email="multi1@example.com", + password_hash="hash", + first_name="User1", + is_active=True, + is_superuser=True + ) + user2 = User( + email="multi2@example.com", + password_hash="hash", + first_name="User2", + is_active=True, + is_superuser=False + ) + user3 = User( + email="multi3@example.com", + password_hash="hash", + first_name="User3", + is_active=False, + is_superuser=True + ) + db_session.add_all([user1, user2, user3]) + db_session.commit() + + # Filter for active superusers + results, _ = user_crud.get_multi_with_total( + db_session, + filters={"is_active": True, "is_superuser": True} + ) + + emails = [u.email for u in results] + assert "multi1@example.com" in emails + assert "multi2@example.com" not in emails + assert "multi3@example.com" not in emails + + def test_get_multi_with_total_nonexistent_sort_field(self, db_session): + """Test sorting by non-existent field is ignored.""" + results, _ = user_crud.get_multi_with_total( + db_session, + sort_by="nonexistent_field", + sort_order="asc" + ) + + # Should not raise an error, just ignore the invalid sort field + assert results is not None + + def test_get_multi_with_total_nonexistent_filter_field(self, db_session): + """Test filtering by non-existent field is ignored.""" + results, _ = user_crud.get_multi_with_total( + db_session, + filters={"nonexistent_field": "value"} + ) + + # Should not raise an error, just ignore the invalid filter + assert results is not None + + def test_get_multi_with_total_none_filter_values(self, db_session): + """Test that None filter values are ignored.""" + user = User( + email="none_filter@example.com", + password_hash="hash", + first_name="None", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + + # Pass None as a filter value - should be ignored + results, _ = user_crud.get_multi_with_total( + db_session, + filters={"is_active": None} + ) + + # Should return all users (not filtered) + assert len(results) >= 1 + + +class TestCRUDCreate: + """Tests for create operations.""" + + def test_create_basic(self, db_session): + """Test basic record creation.""" + user_data = UserCreate( + email="create@example.com", + password="Password123", + first_name="Create", + last_name="Test" + ) + + created = user_crud.create(db_session, obj_in=user_data) + + assert created.id is not None + assert created.email == "create@example.com" + assert created.first_name == "Create" + + def test_create_duplicate_email(self, db_session): + """Test that creating duplicate email raises error.""" + user_data = UserCreate( + email="duplicate@example.com", + password="Password123", + first_name="First" + ) + + # Create first user + user_crud.create(db_session, obj_in=user_data) + + # Try to create duplicate + with pytest.raises(ValueError, match="already exists"): + user_crud.create(db_session, obj_in=user_data) + + +class TestCRUDUpdate: + """Tests for update operations.""" + + def test_update_basic(self, db_session): + """Test basic record update.""" + user = User( + email="update@example.com", + password_hash="hash", + first_name="Original", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + update_data = UserUpdate(first_name="Updated") + updated = user_crud.update(db_session, db_obj=user, obj_in=update_data) + + assert updated.first_name == "Updated" + assert updated.email == "update@example.com" # Unchanged + + def test_update_with_dict(self, db_session): + """Test updating with dictionary.""" + user = User( + email="updatedict@example.com", + password_hash="hash", + first_name="Original", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + update_data = {"first_name": "DictUpdated", "last_name": "DictLast"} + updated = user_crud.update(db_session, db_obj=user, obj_in=update_data) + + assert updated.first_name == "DictUpdated" + assert updated.last_name == "DictLast" + + def test_update_partial(self, db_session): + """Test partial update (only some fields).""" + user = User( + email="partial@example.com", + password_hash="hash", + first_name="First", + last_name="Last", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + # Only update last_name + update_data = UserUpdate(last_name="NewLast") + updated = user_crud.update(db_session, db_obj=user, obj_in=update_data) + + assert updated.first_name == "First" # Unchanged + assert updated.last_name == "NewLast" # Changed + + +class TestCRUDRemove: + """Tests for remove (hard delete) operations.""" + + def test_remove_basic(self, db_session): + """Test basic record removal.""" + user = User( + email="remove@example.com", + password_hash="hash", + first_name="Remove", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + user_id = user.id + + # Remove the user + removed = user_crud.remove(db_session, id=user_id) + + assert removed is not None + assert removed.id == user_id + + # User should no longer exist + retrieved = user_crud.get(db_session, id=user_id) + assert retrieved is None + + def test_remove_nonexistent(self, db_session): + """Test removing non-existent record.""" + fake_id = uuid4() + result = user_crud.remove(db_session, id=fake_id) + assert result is None + + def test_remove_invalid_uuid(self, db_session): + """Test removing with invalid UUID.""" + result = user_crud.remove(db_session, id="not-a-uuid") + assert result is None diff --git a/backend/tests/crud/test_soft_delete.py b/backend/tests/crud/test_soft_delete.py new file mode 100644 index 0000000..cb41e00 --- /dev/null +++ b/backend/tests/crud/test_soft_delete.py @@ -0,0 +1,324 @@ +# tests/crud/test_soft_delete.py +""" +Tests for soft delete functionality in CRUD operations. +""" +import pytest +from datetime import datetime, timezone + +from app.models.user import User +from app.crud.user import user as user_crud + + +class TestSoftDelete: + """Tests for soft delete functionality.""" + + def test_soft_delete_marks_deleted_at(self, db_session): + """Test that soft delete sets deleted_at timestamp.""" + # Create a user + test_user = User( + email="softdelete@example.com", + password_hash="hashedpassword", + first_name="Soft", + last_name="Delete", + is_active=True, + is_superuser=False + ) + db_session.add(test_user) + db_session.commit() + db_session.refresh(test_user) + + user_id = test_user.id + assert test_user.deleted_at is None + + # Soft delete the user + deleted_user = user_crud.soft_delete(db_session, id=user_id) + + assert deleted_user is not None + assert deleted_user.deleted_at is not None + assert isinstance(deleted_user.deleted_at, datetime) + + def test_soft_delete_excludes_from_get_multi(self, db_session): + """Test that soft deleted records are excluded from get_multi.""" + # Create two users + user1 = User( + email="user1@example.com", + password_hash="hash1", + first_name="User", + last_name="One", + is_active=True, + is_superuser=False + ) + user2 = User( + email="user2@example.com", + password_hash="hash2", + first_name="User", + last_name="Two", + is_active=True, + is_superuser=False + ) + db_session.add_all([user1, user2]) + db_session.commit() + db_session.refresh(user1) + db_session.refresh(user2) + + # Both users should be returned + users, total = user_crud.get_multi_with_total(db_session) + assert total >= 2 + user_emails = [u.email for u in users] + assert "user1@example.com" in user_emails + assert "user2@example.com" in user_emails + + # Soft delete user1 + user_crud.soft_delete(db_session, id=user1.id) + + # Only user2 should be returned + users, total = user_crud.get_multi_with_total(db_session) + user_emails = [u.email for u in users] + assert "user1@example.com" not in user_emails + assert "user2@example.com" in user_emails + + def test_soft_delete_still_retrievable_by_get(self, db_session): + """Test that soft deleted records can still be retrieved by get() method.""" + # Create a user + user = User( + email="gettest@example.com", + password_hash="hash", + first_name="Get", + last_name="Test", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + user_id = user.id + + # User should be retrievable + retrieved = user_crud.get(db_session, id=user_id) + assert retrieved is not None + assert retrieved.email == "gettest@example.com" + assert retrieved.deleted_at is None + + # Soft delete the user + user_crud.soft_delete(db_session, id=user_id) + + # User should still be retrievable by ID (soft delete doesn't prevent direct access) + retrieved = user_crud.get(db_session, id=user_id) + assert retrieved is not None + assert retrieved.deleted_at is not None + + def test_soft_delete_nonexistent_record(self, db_session): + """Test soft deleting a record that doesn't exist.""" + import uuid + fake_id = uuid.uuid4() + + result = user_crud.soft_delete(db_session, id=fake_id) + assert result is None + + def test_restore_sets_deleted_at_to_none(self, db_session): + """Test that restore clears the deleted_at timestamp.""" + # Create and soft delete a user + user = User( + email="restore@example.com", + password_hash="hash", + first_name="Restore", + last_name="Test", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + user_id = user.id + + # Soft delete + user_crud.soft_delete(db_session, id=user_id) + db_session.refresh(user) + assert user.deleted_at is not None + + # Restore + restored_user = user_crud.restore(db_session, id=user_id) + + assert restored_user is not None + assert restored_user.deleted_at is None + + def test_restore_makes_record_available(self, db_session): + """Test that restored records appear in queries.""" + # Create and soft delete a user + user = User( + email="available@example.com", + password_hash="hash", + first_name="Available", + last_name="Test", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + user_id = user.id + user_email = user.email + + # Soft delete + user_crud.soft_delete(db_session, id=user_id) + + # User should not be in query results + users, _ = user_crud.get_multi_with_total(db_session) + emails = [u.email for u in users] + assert user_email not in emails + + # Restore + user_crud.restore(db_session, id=user_id) + + # User should now be in query results + users, _ = user_crud.get_multi_with_total(db_session) + emails = [u.email for u in users] + assert user_email in emails + + def test_restore_nonexistent_record(self, db_session): + """Test restoring a record that doesn't exist.""" + import uuid + fake_id = uuid.uuid4() + + result = user_crud.restore(db_session, id=fake_id) + assert result is None + + def test_restore_already_active_record(self, db_session): + """Test restoring a record that was never deleted returns None.""" + # Create a user (not deleted) + user = User( + email="never_deleted@example.com", + password_hash="hash", + first_name="Never", + last_name="Deleted", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + user_id = user.id + assert user.deleted_at is None + + # Restore should return None (record is not soft-deleted) + restored = user_crud.restore(db_session, id=user_id) + assert restored is None + + def test_soft_delete_multiple_times(self, db_session): + """Test soft deleting the same record multiple times.""" + # Create a user + user = User( + email="multiple_delete@example.com", + password_hash="hash", + first_name="Multiple", + last_name="Delete", + is_active=True, + is_superuser=False + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + user_id = user.id + + # First soft delete + first_deleted = user_crud.soft_delete(db_session, id=user_id) + assert first_deleted is not None + first_timestamp = first_deleted.deleted_at + + # Restore + user_crud.restore(db_session, id=user_id) + + # Second soft delete + second_deleted = user_crud.soft_delete(db_session, id=user_id) + assert second_deleted is not None + second_timestamp = second_deleted.deleted_at + + # Timestamps should be different + assert second_timestamp != first_timestamp + assert second_timestamp > first_timestamp + + def test_get_multi_with_filters_excludes_deleted(self, db_session): + """Test that get_multi_with_total with filters excludes deleted records.""" + # Create active and inactive users + active_user = User( + email="active_not_deleted@example.com", + password_hash="hash", + first_name="Active", + last_name="NotDeleted", + is_active=True, + is_superuser=False + ) + inactive_user = User( + email="inactive_not_deleted@example.com", + password_hash="hash", + first_name="Inactive", + last_name="NotDeleted", + is_active=False, + is_superuser=False + ) + deleted_active_user = User( + email="active_deleted@example.com", + password_hash="hash", + first_name="Active", + last_name="Deleted", + is_active=True, + is_superuser=False + ) + + db_session.add_all([active_user, inactive_user, deleted_active_user]) + db_session.commit() + db_session.refresh(deleted_active_user) + + # Soft delete one active user + user_crud.soft_delete(db_session, id=deleted_active_user.id) + + # Filter for active users - should only return non-deleted active user + users, total = user_crud.get_multi_with_total( + db_session, + filters={"is_active": True} + ) + + emails = [u.email for u in users] + assert "active_not_deleted@example.com" in emails + assert "active_deleted@example.com" not in emails + assert "inactive_not_deleted@example.com" not in emails + + def test_soft_delete_preserves_other_fields(self, db_session): + """Test that soft delete doesn't modify other fields.""" + # Create a user with specific data + user = User( + email="preserve@example.com", + password_hash="original_hash", + first_name="Preserve", + last_name="Fields", + phone_number="+1234567890", + is_active=True, + is_superuser=False, + preferences={"theme": "dark"} + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + + user_id = user.id + original_email = user.email + original_hash = user.password_hash + original_first_name = user.first_name + original_phone = user.phone_number + original_preferences = user.preferences + + # Soft delete + deleted = user_crud.soft_delete(db_session, id=user_id) + + # All other fields should remain unchanged + assert deleted.email == original_email + assert deleted.password_hash == original_hash + assert deleted.first_name == original_first_name + assert deleted.phone_number == original_phone + assert deleted.preferences == original_preferences + assert deleted.is_active is True # is_active unchanged diff --git a/backend/tests/services/test_email_service.py b/backend/tests/services/test_email_service.py new file mode 100644 index 0000000..77d8c9d --- /dev/null +++ b/backend/tests/services/test_email_service.py @@ -0,0 +1,281 @@ +# tests/services/test_email_service.py +""" +Tests for email service functionality. +""" +import pytest +from unittest.mock import patch, AsyncMock, MagicMock + +from app.services.email_service import ( + EmailService, + ConsoleEmailBackend, + SMTPEmailBackend +) + + +class TestConsoleEmailBackend: + """Tests for ConsoleEmailBackend.""" + + @pytest.mark.asyncio + async def test_send_email_basic(self): + """Test basic email sending with console backend.""" + backend = ConsoleEmailBackend() + + result = await backend.send_email( + to=["user@example.com"], + subject="Test Subject", + html_content="
Test HTML
", + text_content="Test Text" + ) + + assert result is True + + @pytest.mark.asyncio + async def test_send_email_without_text_content(self): + """Test sending email without plain text version.""" + backend = ConsoleEmailBackend() + + result = await backend.send_email( + to=["user@example.com"], + subject="Test Subject", + html_content="Test HTML
" + ) + + assert result is True + + @pytest.mark.asyncio + async def test_send_email_multiple_recipients(self): + """Test sending email to multiple recipients.""" + backend = ConsoleEmailBackend() + + result = await backend.send_email( + to=["user1@example.com", "user2@example.com"], + subject="Test Subject", + html_content="Test HTML
" + ) + + assert result is True + + +class TestSMTPEmailBackend: + """Tests for SMTPEmailBackend.""" + + @pytest.mark.asyncio + async def test_smtp_backend_initialization(self): + """Test SMTP backend initialization.""" + backend = SMTPEmailBackend( + host="smtp.example.com", + port=587, + username="test@example.com", + password="password" + ) + + assert backend.host == "smtp.example.com" + assert backend.port == 587 + assert backend.username == "test@example.com" + assert backend.password == "password" + + @pytest.mark.asyncio + async def test_smtp_backend_fallback_to_console(self): + """Test that SMTP backend falls back to console when not implemented.""" + backend = SMTPEmailBackend( + host="smtp.example.com", + port=587, + username="test@example.com", + password="password" + ) + + # Should fall back to console backend since SMTP is not implemented + result = await backend.send_email( + to=["user@example.com"], + subject="Test Subject", + html_content="Test HTML
" + ) + + assert result is True + + +class TestEmailService: + """Tests for EmailService.""" + + def test_email_service_default_backend(self): + """Test that EmailService uses ConsoleEmailBackend by default.""" + service = EmailService() + assert isinstance(service.backend, ConsoleEmailBackend) + + def test_email_service_custom_backend(self): + """Test EmailService with custom backend.""" + custom_backend = ConsoleEmailBackend() + service = EmailService(backend=custom_backend) + assert service.backend is custom_backend + + @pytest.mark.asyncio + async def test_send_password_reset_email(self): + """Test sending password reset email.""" + service = EmailService() + + result = await service.send_password_reset_email( + to_email="user@example.com", + reset_token="test_token_123", + user_name="John" + ) + + assert result is True + + @pytest.mark.asyncio + async def test_send_password_reset_email_without_name(self): + """Test sending password reset email without user name.""" + service = EmailService() + + result = await service.send_password_reset_email( + to_email="user@example.com", + reset_token="test_token_123" + ) + + assert result is True + + @pytest.mark.asyncio + async def test_send_password_reset_email_includes_token_in_url(self): + """Test that password reset email includes token in URL.""" + backend_mock = AsyncMock(spec=ConsoleEmailBackend) + backend_mock.send_email = AsyncMock(return_value=True) + service = EmailService(backend=backend_mock) + + token = "test_reset_token_xyz" + await service.send_password_reset_email( + to_email="user@example.com", + reset_token=token + ) + + # Verify send_email was called + backend_mock.send_email.assert_called_once() + call_args = backend_mock.send_email.call_args + + # Check that token is in the HTML content + html_content = call_args.kwargs['html_content'] + assert token in html_content + + @pytest.mark.asyncio + async def test_send_password_reset_email_error_handling(self): + """Test error handling in password reset email.""" + backend_mock = AsyncMock(spec=ConsoleEmailBackend) + backend_mock.send_email = AsyncMock(side_effect=Exception("SMTP Error")) + service = EmailService(backend=backend_mock) + + result = await service.send_password_reset_email( + to_email="user@example.com", + reset_token="test_token" + ) + + assert result is False + + @pytest.mark.asyncio + async def test_send_email_verification(self): + """Test sending email verification email.""" + service = EmailService() + + result = await service.send_email_verification( + to_email="user@example.com", + verification_token="verification_token_123", + user_name="Jane" + ) + + assert result is True + + @pytest.mark.asyncio + async def test_send_email_verification_without_name(self): + """Test sending email verification without user name.""" + service = EmailService() + + result = await service.send_email_verification( + to_email="user@example.com", + verification_token="verification_token_123" + ) + + assert result is True + + @pytest.mark.asyncio + async def test_send_email_verification_includes_token(self): + """Test that email verification includes token in URL.""" + backend_mock = AsyncMock(spec=ConsoleEmailBackend) + backend_mock.send_email = AsyncMock(return_value=True) + service = EmailService(backend=backend_mock) + + token = "test_verification_token_xyz" + await service.send_email_verification( + to_email="user@example.com", + verification_token=token + ) + + # Verify send_email was called + backend_mock.send_email.assert_called_once() + call_args = backend_mock.send_email.call_args + + # Check that token is in the HTML content + html_content = call_args.kwargs['html_content'] + assert token in html_content + + @pytest.mark.asyncio + async def test_send_email_verification_error_handling(self): + """Test error handling in email verification.""" + backend_mock = AsyncMock(spec=ConsoleEmailBackend) + backend_mock.send_email = AsyncMock(side_effect=Exception("Email Error")) + service = EmailService(backend=backend_mock) + + result = await service.send_email_verification( + to_email="user@example.com", + verification_token="test_token" + ) + + assert result is False + + @pytest.mark.asyncio + async def test_password_reset_email_contains_required_elements(self): + """Test that password reset email has all required elements.""" + backend_mock = AsyncMock(spec=ConsoleEmailBackend) + backend_mock.send_email = AsyncMock(return_value=True) + service = EmailService(backend=backend_mock) + + await service.send_password_reset_email( + to_email="user@example.com", + reset_token="token123", + user_name="Test User" + ) + + call_args = backend_mock.send_email.call_args + html_content = call_args.kwargs['html_content'] + text_content = call_args.kwargs['text_content'] + + # Check HTML content + assert "Password Reset" in html_content + assert "token123" in html_content + assert "Test User" in html_content + + # Check text content + assert "Password Reset" in text_content or "password reset" in text_content.lower() + assert "token123" in text_content + + @pytest.mark.asyncio + async def test_verification_email_contains_required_elements(self): + """Test that verification email has all required elements.""" + backend_mock = AsyncMock(spec=ConsoleEmailBackend) + backend_mock.send_email = AsyncMock(return_value=True) + service = EmailService(backend=backend_mock) + + await service.send_email_verification( + to_email="user@example.com", + verification_token="verify123", + user_name="Test User" + ) + + call_args = backend_mock.send_email.call_args + html_content = call_args.kwargs['html_content'] + text_content = call_args.kwargs['text_content'] + + # Check HTML content + assert "Verify" in html_content + assert "verify123" in html_content + assert "Test User" in html_content + + # Check text content + assert "verify" in text_content.lower() + assert "verify123" in text_content diff --git a/backend/tests/utils/test_security.py b/backend/tests/utils/test_security.py index 5434281..52c4a9e 100644 --- a/backend/tests/utils/test_security.py +++ b/backend/tests/utils/test_security.py @@ -8,7 +8,14 @@ import json import pytest from unittest.mock import patch, MagicMock -from app.utils.security import create_upload_token, verify_upload_token +from app.utils.security import ( + create_upload_token, + verify_upload_token, + create_password_reset_token, + verify_password_reset_token, + create_email_verification_token, + verify_email_verification_token +) class TestCreateUploadToken: @@ -231,3 +238,189 @@ class TestVerifyUploadToken: # The signature validation is already tested by test_verify_invalid_signature # and test_verify_tampered_payload. Testing with different SECRET_KEY # requires complex mocking that can interfere with other tests. + + +class TestPasswordResetTokens: + """Tests for password reset token functions.""" + + def test_create_password_reset_token(self): + """Test creating a password reset token.""" + email = "user@example.com" + token = create_password_reset_token(email) + + assert token is not None + assert isinstance(token, str) + assert len(token) > 0 + + def test_verify_password_reset_token_valid(self): + """Test verifying a valid password reset token.""" + email = "user@example.com" + token = create_password_reset_token(email) + + verified_email = verify_password_reset_token(token) + + assert verified_email == email + + def test_verify_password_reset_token_expired(self): + """Test that expired password reset tokens are rejected.""" + email = "user@example.com" + + # Create token that expires in 1 second + with patch('app.utils.security.time') as mock_time: + mock_time.time = MagicMock(return_value=1000000) + token = create_password_reset_token(email, expires_in=1) + + # Fast forward time + mock_time.time.return_value = 1000002 + + verified_email = verify_password_reset_token(token) + assert verified_email is None + + def test_verify_password_reset_token_invalid(self): + """Test that invalid tokens are rejected.""" + assert verify_password_reset_token("invalid_token") is None + assert verify_password_reset_token("") is None + + def test_verify_password_reset_token_tampered(self): + """Test that tampered tokens are rejected.""" + email = "user@example.com" + token = create_password_reset_token(email) + + # Decode and tamper + decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + token_data = json.loads(decoded) + token_data["payload"]["email"] = "hacker@example.com" + + # Re-encode + tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8') + + verified_email = verify_password_reset_token(tampered) + assert verified_email is None + + def test_verify_password_reset_token_wrong_purpose(self): + """Test that email verification tokens can't be used for password reset.""" + email = "user@example.com" + # Create an email verification token + token = create_email_verification_token(email) + + # Try to verify as password reset token + verified_email = verify_password_reset_token(token) + assert verified_email is None + + def test_password_reset_token_custom_expiration(self): + """Test password reset token with custom expiration.""" + email = "user@example.com" + custom_exp = 7200 # 2 hours + + with patch('app.utils.security.time') as mock_time: + current_time = 1000000 + mock_time.time = MagicMock(return_value=current_time) + + token = create_password_reset_token(email, expires_in=custom_exp) + + # Decode to check expiration + decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + token_data = json.loads(decoded) + + assert token_data["payload"]["exp"] == current_time + custom_exp + + +class TestEmailVerificationTokens: + """Tests for email verification token functions.""" + + def test_create_email_verification_token(self): + """Test creating an email verification token.""" + email = "user@example.com" + token = create_email_verification_token(email) + + assert token is not None + assert isinstance(token, str) + assert len(token) > 0 + + def test_verify_email_verification_token_valid(self): + """Test verifying a valid email verification token.""" + email = "user@example.com" + token = create_email_verification_token(email) + + verified_email = verify_email_verification_token(token) + + assert verified_email == email + + def test_verify_email_verification_token_expired(self): + """Test that expired verification tokens are rejected.""" + email = "user@example.com" + + with patch('app.utils.security.time') as mock_time: + mock_time.time = MagicMock(return_value=1000000) + token = create_email_verification_token(email, expires_in=1) + + # Fast forward time + mock_time.time.return_value = 1000002 + + verified_email = verify_email_verification_token(token) + assert verified_email is None + + def test_verify_email_verification_token_invalid(self): + """Test that invalid tokens are rejected.""" + assert verify_email_verification_token("invalid_token") is None + assert verify_email_verification_token("") is None + + def test_verify_email_verification_token_tampered(self): + """Test that tampered verification tokens are rejected.""" + email = "user@example.com" + token = create_email_verification_token(email) + + # Decode and tamper + decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + token_data = json.loads(decoded) + token_data["payload"]["email"] = "hacker@example.com" + + # Re-encode + tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8') + + verified_email = verify_email_verification_token(tampered) + assert verified_email is None + + def test_verify_email_verification_token_wrong_purpose(self): + """Test that password reset tokens can't be used for email verification.""" + email = "user@example.com" + # Create a password reset token + token = create_password_reset_token(email) + + # Try to verify as email verification token + verified_email = verify_email_verification_token(token) + assert verified_email is None + + def test_email_verification_token_default_expiration(self): + """Test email verification token with default 24-hour expiration.""" + email = "user@example.com" + + with patch('app.utils.security.time') as mock_time: + current_time = 1000000 + mock_time.time = MagicMock(return_value=current_time) + + token = create_email_verification_token(email) + + # Decode to check expiration (should be 86400 seconds = 24 hours) + decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + token_data = json.loads(decoded) + + assert token_data["payload"]["exp"] == current_time + 86400 + + def test_tokens_are_unique(self): + """Test that multiple tokens for the same email are unique.""" + email = "user@example.com" + + token1 = create_password_reset_token(email) + token2 = create_password_reset_token(email) + + assert token1 != token2 + + def test_verification_and_reset_tokens_are_different(self): + """Test that verification and reset tokens for same email are different.""" + email = "user@example.com" + + reset_token = create_password_reset_token(email) + verify_token = create_email_verification_token(email) + + assert reset_token != verify_token