Remove CRUD test modules for unused and deprecated features
- Deleted `test_crud_base.py`, `test_crud_error_paths.py`, and `test_organization_async.py` due to the removal of corresponding deprecated CRUD implementations. - Improved codebase maintainability and reduced test suite noise by eliminating obsolete test files.
This commit is contained in:
@@ -9,8 +9,8 @@ from app.main import app
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create a FastAPI test client for the main app."""
|
||||
# Mock get_async_db to avoid database connection issues
|
||||
with patch("app.core.database_async.get_async_db") as mock_get_db:
|
||||
# Mock get_db to avoid database connection issues
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
async def mock_session_generator():
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
mock_session = MagicMock()
|
||||
|
||||
@@ -12,7 +12,7 @@ from httpx import AsyncClient, ASGITransport
|
||||
os.environ["IS_TEST"] = "True"
|
||||
|
||||
from app.main import app
|
||||
from app.core.database_async import get_async_db
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
|
||||
@@ -100,18 +100,18 @@ async def client(async_test_db):
|
||||
"""
|
||||
Create a FastAPI async test client with a test database.
|
||||
|
||||
This overrides the get_async_db dependency to use the test database.
|
||||
This overrides the get_db dependency to use the test database.
|
||||
"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async def override_get_async_db():
|
||||
async def override_get_db():
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
pass
|
||||
|
||||
app.dependency_overrides[get_async_db] = override_get_async_db
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
# Use ASGITransport for httpx >= 0.27
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
@@ -1,448 +0,0 @@
|
||||
# tests/crud/test_crud_base.py
|
||||
"""
|
||||
Tests for CRUD base operations.
|
||||
"""
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
|
||||
from app.models.user import User
|
||||
from app.crud.user import user as user_crud
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class TestCRUDGet:
|
||||
"""Tests for CRUD get operations."""
|
||||
|
||||
def test_get_by_valid_uuid(self, db_session):
|
||||
"""Test getting a record by valid UUID."""
|
||||
user = User(
|
||||
email="get_uuid@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Get",
|
||||
last_name="UUID",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
retrieved = user_crud.get(db_session, id=user.id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == user.id
|
||||
assert retrieved.email == user.email
|
||||
|
||||
def test_get_by_string_uuid(self, db_session):
|
||||
"""Test getting a record by UUID string."""
|
||||
user = User(
|
||||
email="get_string@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Get",
|
||||
last_name="String",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
retrieved = user_crud.get(db_session, id=str(user.id))
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == user.id
|
||||
|
||||
def test_get_nonexistent(self, db_session):
|
||||
"""Test getting a non-existent record."""
|
||||
fake_id = uuid4()
|
||||
result = user_crud.get(db_session, id=fake_id)
|
||||
assert result is None
|
||||
|
||||
def test_get_invalid_uuid(self, db_session):
|
||||
"""Test getting with invalid UUID format."""
|
||||
result = user_crud.get(db_session, id="not-a-uuid")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCRUDGetMulti:
|
||||
"""Tests for get_multi operations."""
|
||||
|
||||
def test_get_multi_basic(self, db_session):
|
||||
"""Test basic get_multi functionality."""
|
||||
# Create multiple users
|
||||
users = [
|
||||
User(email=f"multi{i}@example.com", password_hash="hash", first_name=f"User{i}",
|
||||
is_active=True, is_superuser=False)
|
||||
for i in range(5)
|
||||
]
|
||||
db_session.add_all(users)
|
||||
db_session.commit()
|
||||
|
||||
results = user_crud.get_multi(db_session, skip=0, limit=10)
|
||||
assert len(results) >= 5
|
||||
|
||||
def test_get_multi_pagination(self, db_session):
|
||||
"""Test pagination with get_multi."""
|
||||
# Create users
|
||||
users = [
|
||||
User(email=f"page{i}@example.com", password_hash="hash", first_name=f"Page{i}",
|
||||
is_active=True, is_superuser=False)
|
||||
for i in range(10)
|
||||
]
|
||||
db_session.add_all(users)
|
||||
db_session.commit()
|
||||
|
||||
# First page
|
||||
page1 = user_crud.get_multi(db_session, skip=0, limit=3)
|
||||
assert len(page1) == 3
|
||||
|
||||
# Second page
|
||||
page2 = user_crud.get_multi(db_session, skip=3, limit=3)
|
||||
assert len(page2) == 3
|
||||
|
||||
# Pages should have different users
|
||||
page1_ids = {u.id for u in page1}
|
||||
page2_ids = {u.id for u in page2}
|
||||
assert len(page1_ids.intersection(page2_ids)) == 0
|
||||
|
||||
def test_get_multi_negative_skip(self, db_session):
|
||||
"""Test that negative skip raises ValueError."""
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
user_crud.get_multi(db_session, skip=-1, limit=10)
|
||||
|
||||
def test_get_multi_negative_limit(self, db_session):
|
||||
"""Test that negative limit raises ValueError."""
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
user_crud.get_multi(db_session, skip=0, limit=-1)
|
||||
|
||||
def test_get_multi_limit_too_large(self, db_session):
|
||||
"""Test that limit over 1000 raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
user_crud.get_multi(db_session, skip=0, limit=1001)
|
||||
|
||||
|
||||
class TestCRUDGetMultiWithTotal:
|
||||
"""Tests for get_multi_with_total operations."""
|
||||
|
||||
def test_get_multi_with_total_basic(self, db_session):
|
||||
"""Test basic get_multi_with_total functionality."""
|
||||
# Create users
|
||||
users = [
|
||||
User(email=f"total{i}@example.com", password_hash="hash", first_name=f"Total{i}",
|
||||
is_active=True, is_superuser=False)
|
||||
for i in range(7)
|
||||
]
|
||||
db_session.add_all(users)
|
||||
db_session.commit()
|
||||
|
||||
results, total = user_crud.get_multi_with_total(db_session, skip=0, limit=10)
|
||||
assert total >= 7
|
||||
assert len(results) >= 7
|
||||
|
||||
def test_get_multi_with_total_pagination(self, db_session):
|
||||
"""Test pagination returns correct total."""
|
||||
# Create users
|
||||
users = [
|
||||
User(email=f"pagetotal{i}@example.com", password_hash="hash", first_name=f"PageTotal{i}",
|
||||
is_active=True, is_superuser=False)
|
||||
for i in range(15)
|
||||
]
|
||||
db_session.add_all(users)
|
||||
db_session.commit()
|
||||
|
||||
# First page
|
||||
page1, total1 = user_crud.get_multi_with_total(db_session, skip=0, limit=5)
|
||||
assert len(page1) == 5
|
||||
assert total1 >= 15
|
||||
|
||||
# Second page should have same total
|
||||
page2, total2 = user_crud.get_multi_with_total(db_session, skip=5, limit=5)
|
||||
assert len(page2) == 5
|
||||
assert total2 == total1
|
||||
|
||||
def test_get_multi_with_total_sorting_asc(self, db_session):
|
||||
"""Test sorting in ascending order."""
|
||||
# Create users
|
||||
users = [
|
||||
User(email=f"sort{i}@example.com", password_hash="hash", first_name=f"User{chr(90-i)}",
|
||||
is_active=True, is_superuser=False)
|
||||
for i in range(5)
|
||||
]
|
||||
db_session.add_all(users)
|
||||
db_session.commit()
|
||||
|
||||
results, _ = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
sort_by="first_name",
|
||||
sort_order="asc"
|
||||
)
|
||||
|
||||
# Check that results are sorted
|
||||
first_names = [u.first_name for u in results if u.first_name.startswith("User")]
|
||||
assert first_names == sorted(first_names)
|
||||
|
||||
def test_get_multi_with_total_sorting_desc(self, db_session):
|
||||
"""Test sorting in descending order."""
|
||||
# Create users
|
||||
users = [
|
||||
User(email=f"desc{i}@example.com", password_hash="hash", first_name=f"User{chr(65+i)}",
|
||||
is_active=True, is_superuser=False)
|
||||
for i in range(5)
|
||||
]
|
||||
db_session.add_all(users)
|
||||
db_session.commit()
|
||||
|
||||
results, _ = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
sort_by="first_name",
|
||||
sort_order="desc"
|
||||
)
|
||||
|
||||
# Check that results are sorted descending
|
||||
first_names = [u.first_name for u in results if u.first_name.startswith("User")]
|
||||
assert first_names == sorted(first_names, reverse=True)
|
||||
|
||||
def test_get_multi_with_total_filtering(self, db_session):
|
||||
"""Test filtering with get_multi_with_total."""
|
||||
# Create active and inactive users
|
||||
active_user = User(
|
||||
email="active_filter@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Active",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
inactive_user = User(
|
||||
email="inactive_filter@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Inactive",
|
||||
is_active=False,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add_all([active_user, inactive_user])
|
||||
db_session.commit()
|
||||
|
||||
# Filter for active users only
|
||||
results, total = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
filters={"is_active": True}
|
||||
)
|
||||
|
||||
emails = [u.email for u in results]
|
||||
assert "active_filter@example.com" in emails
|
||||
assert "inactive_filter@example.com" not in emails
|
||||
|
||||
def test_get_multi_with_total_multiple_filters(self, db_session):
|
||||
"""Test multiple filters."""
|
||||
# Create users with different combinations
|
||||
user1 = User(
|
||||
email="multi1@example.com",
|
||||
password_hash="hash",
|
||||
first_name="User1",
|
||||
is_active=True,
|
||||
is_superuser=True
|
||||
)
|
||||
user2 = User(
|
||||
email="multi2@example.com",
|
||||
password_hash="hash",
|
||||
first_name="User2",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
user3 = User(
|
||||
email="multi3@example.com",
|
||||
password_hash="hash",
|
||||
first_name="User3",
|
||||
is_active=False,
|
||||
is_superuser=True
|
||||
)
|
||||
db_session.add_all([user1, user2, user3])
|
||||
db_session.commit()
|
||||
|
||||
# Filter for active superusers
|
||||
results, _ = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
filters={"is_active": True, "is_superuser": True}
|
||||
)
|
||||
|
||||
emails = [u.email for u in results]
|
||||
assert "multi1@example.com" in emails
|
||||
assert "multi2@example.com" not in emails
|
||||
assert "multi3@example.com" not in emails
|
||||
|
||||
def test_get_multi_with_total_nonexistent_sort_field(self, db_session):
|
||||
"""Test sorting by non-existent field is ignored."""
|
||||
results, _ = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
sort_by="nonexistent_field",
|
||||
sort_order="asc"
|
||||
)
|
||||
|
||||
# Should not raise an error, just ignore the invalid sort field
|
||||
assert results is not None
|
||||
|
||||
def test_get_multi_with_total_nonexistent_filter_field(self, db_session):
|
||||
"""Test filtering by non-existent field is ignored."""
|
||||
results, _ = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
filters={"nonexistent_field": "value"}
|
||||
)
|
||||
|
||||
# Should not raise an error, just ignore the invalid filter
|
||||
assert results is not None
|
||||
|
||||
def test_get_multi_with_total_none_filter_values(self, db_session):
|
||||
"""Test that None filter values are ignored."""
|
||||
user = User(
|
||||
email="none_filter@example.com",
|
||||
password_hash="hash",
|
||||
first_name="None",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
# Pass None as a filter value - should be ignored
|
||||
results, _ = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
filters={"is_active": None}
|
||||
)
|
||||
|
||||
# Should return all users (not filtered)
|
||||
assert len(results) >= 1
|
||||
|
||||
|
||||
class TestCRUDCreate:
|
||||
"""Tests for create operations."""
|
||||
|
||||
def test_create_basic(self, db_session):
|
||||
"""Test basic record creation."""
|
||||
user_data = UserCreate(
|
||||
email="create@example.com",
|
||||
password="Password123!",
|
||||
first_name="Create",
|
||||
last_name="Test"
|
||||
)
|
||||
|
||||
created = user_crud.create(db_session, obj_in=user_data)
|
||||
|
||||
assert created.id is not None
|
||||
assert created.email == "create@example.com"
|
||||
assert created.first_name == "Create"
|
||||
|
||||
def test_create_duplicate_email(self, db_session):
|
||||
"""Test that creating duplicate email raises error."""
|
||||
user_data = UserCreate(
|
||||
email="duplicate@example.com",
|
||||
password="Password123!",
|
||||
first_name="First"
|
||||
)
|
||||
|
||||
# Create first user
|
||||
user_crud.create(db_session, obj_in=user_data)
|
||||
|
||||
# Try to create duplicate
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
user_crud.create(db_session, obj_in=user_data)
|
||||
|
||||
|
||||
class TestCRUDUpdate:
|
||||
"""Tests for update operations."""
|
||||
|
||||
def test_update_basic(self, db_session):
|
||||
"""Test basic record update."""
|
||||
user = User(
|
||||
email="update@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Original",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
update_data = UserUpdate(first_name="Updated")
|
||||
updated = user_crud.update(db_session, db_obj=user, obj_in=update_data)
|
||||
|
||||
assert updated.first_name == "Updated"
|
||||
assert updated.email == "update@example.com" # Unchanged
|
||||
|
||||
def test_update_with_dict(self, db_session):
|
||||
"""Test updating with dictionary."""
|
||||
user = User(
|
||||
email="updatedict@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Original",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
update_data = {"first_name": "DictUpdated", "last_name": "DictLast"}
|
||||
updated = user_crud.update(db_session, db_obj=user, obj_in=update_data)
|
||||
|
||||
assert updated.first_name == "DictUpdated"
|
||||
assert updated.last_name == "DictLast"
|
||||
|
||||
def test_update_partial(self, db_session):
|
||||
"""Test partial update (only some fields)."""
|
||||
user = User(
|
||||
email="partial@example.com",
|
||||
password_hash="hash",
|
||||
first_name="First",
|
||||
last_name="Last",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
# Only update last_name
|
||||
update_data = UserUpdate(last_name="NewLast")
|
||||
updated = user_crud.update(db_session, db_obj=user, obj_in=update_data)
|
||||
|
||||
assert updated.first_name == "First" # Unchanged
|
||||
assert updated.last_name == "NewLast" # Changed
|
||||
|
||||
|
||||
class TestCRUDRemove:
|
||||
"""Tests for remove (hard delete) operations."""
|
||||
|
||||
def test_remove_basic(self, db_session):
|
||||
"""Test basic record removal."""
|
||||
user = User(
|
||||
email="remove@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Remove",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
user_id = user.id
|
||||
|
||||
# Remove the user
|
||||
removed = user_crud.remove(db_session, id=user_id)
|
||||
|
||||
assert removed is not None
|
||||
assert removed.id == user_id
|
||||
|
||||
# User should no longer exist
|
||||
retrieved = user_crud.get(db_session, id=user_id)
|
||||
assert retrieved is None
|
||||
|
||||
def test_remove_nonexistent(self, db_session):
|
||||
"""Test removing non-existent record."""
|
||||
fake_id = uuid4()
|
||||
result = user_crud.remove(db_session, id=fake_id)
|
||||
assert result is None
|
||||
|
||||
def test_remove_invalid_uuid(self, db_session):
|
||||
"""Test removing with invalid UUID."""
|
||||
result = user_crud.remove(db_session, id="not-a-uuid")
|
||||
assert result is None
|
||||
@@ -1,295 +0,0 @@
|
||||
# tests/crud/test_crud_error_paths.py
|
||||
"""
|
||||
Tests for CRUD error handling paths to increase coverage.
|
||||
These tests focus on exception handling and edge cases.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
||||
from app.models.user import User
|
||||
from app.crud.user import user as user_crud
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class TestCRUDErrorPaths:
|
||||
"""Tests for error handling in CRUD operations."""
|
||||
|
||||
def test_get_database_error(self, db_session):
|
||||
"""Test get method handles database errors."""
|
||||
import uuid
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with patch.object(db_session, 'query') as mock_query:
|
||||
mock_query.side_effect = OperationalError("statement", "params", "orig")
|
||||
|
||||
with pytest.raises(OperationalError):
|
||||
user_crud.get(db_session, id=user_id)
|
||||
|
||||
def test_get_multi_database_error(self, db_session):
|
||||
"""Test get_multi handles database errors."""
|
||||
with patch.object(db_session, 'query') as mock_query:
|
||||
mock_query.side_effect = OperationalError("statement", "params", "orig")
|
||||
|
||||
with pytest.raises(OperationalError):
|
||||
user_crud.get_multi(db_session, skip=0, limit=10)
|
||||
|
||||
def test_create_integrity_error_non_unique(self, db_session):
|
||||
"""Test create handles integrity errors for non-unique constraints."""
|
||||
# Create first user
|
||||
user_data = UserCreate(
|
||||
email="unique@example.com",
|
||||
password="Password123!",
|
||||
first_name="First"
|
||||
)
|
||||
user_crud.create(db_session, obj_in=user_data)
|
||||
|
||||
# Try to create duplicate
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
user_crud.create(db_session, obj_in=user_data)
|
||||
|
||||
def test_create_generic_integrity_error(self, db_session):
|
||||
"""Test create handles other integrity errors."""
|
||||
user_data = UserCreate(
|
||||
email="integrityerror@example.com",
|
||||
password="Password123!",
|
||||
first_name="Integrity"
|
||||
)
|
||||
|
||||
with patch('app.crud.base.jsonable_encoder') as mock_encoder:
|
||||
mock_encoder.return_value = {"email": "test@example.com"}
|
||||
|
||||
with patch.object(db_session, 'add') as mock_add:
|
||||
# Simulate a non-unique integrity error
|
||||
error = IntegrityError("statement", "params", Exception("check constraint failed"))
|
||||
mock_add.side_effect = error
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
user_crud.create(db_session, obj_in=user_data)
|
||||
|
||||
def test_create_unexpected_error(self, db_session):
|
||||
"""Test create handles unexpected errors."""
|
||||
user_data = UserCreate(
|
||||
email="unexpectederror@example.com",
|
||||
password="Password123!",
|
||||
first_name="Unexpected"
|
||||
)
|
||||
|
||||
with patch.object(db_session, 'commit') as mock_commit:
|
||||
mock_commit.side_effect = Exception("Unexpected database error")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
user_crud.create(db_session, obj_in=user_data)
|
||||
|
||||
def test_update_integrity_error(self, db_session):
|
||||
"""Test update handles integrity errors."""
|
||||
# Create a user
|
||||
user = User(
|
||||
email="updateintegrity@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Update",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
# Create another user with a different email
|
||||
user2 = User(
|
||||
email="another@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Another",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user2)
|
||||
db_session.commit()
|
||||
|
||||
# Try to update user to have the same email as user2
|
||||
with patch.object(db_session, 'commit') as mock_commit:
|
||||
error = IntegrityError("statement", "params", Exception("UNIQUE constraint failed"))
|
||||
mock_commit.side_effect = error
|
||||
|
||||
update_data = UserUpdate(email="another@example.com")
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
user_crud.update(db_session, db_obj=user, obj_in=update_data)
|
||||
|
||||
def test_update_unexpected_error(self, db_session):
|
||||
"""Test update handles unexpected errors."""
|
||||
user = User(
|
||||
email="updateunexpected@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Update",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
with patch.object(db_session, 'commit') as mock_commit:
|
||||
mock_commit.side_effect = Exception("Unexpected database error")
|
||||
|
||||
update_data = UserUpdate(first_name="Error")
|
||||
with pytest.raises(Exception):
|
||||
user_crud.update(db_session, db_obj=user, obj_in=update_data)
|
||||
|
||||
def test_remove_with_relationships(self, db_session):
|
||||
"""Test remove handles cascade deletes."""
|
||||
user = User(
|
||||
email="removerelations@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Remove",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
# Remove should succeed even with potential relationships
|
||||
removed = user_crud.remove(db_session, id=user.id)
|
||||
assert removed is not None
|
||||
assert removed.id == user.id
|
||||
|
||||
def test_soft_delete_database_error(self, db_session):
|
||||
"""Test soft_delete handles database errors."""
|
||||
user = User(
|
||||
email="softdeleteerror@example.com",
|
||||
password_hash="hash",
|
||||
first_name="SoftDelete",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
with patch.object(db_session, 'commit') as mock_commit:
|
||||
mock_commit.side_effect = Exception("Database error")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
user_crud.soft_delete(db_session, id=user.id)
|
||||
|
||||
def test_restore_database_error(self, db_session):
|
||||
"""Test restore handles database errors."""
|
||||
user = User(
|
||||
email="restoreerror@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Restore",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
# First soft delete
|
||||
user_crud.soft_delete(db_session, id=user.id)
|
||||
|
||||
# Then try to restore with error
|
||||
with patch.object(db_session, 'commit') as mock_commit:
|
||||
mock_commit.side_effect = Exception("Database error")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
user_crud.restore(db_session, id=user.id)
|
||||
|
||||
def test_get_multi_with_total_error_recovery(self, db_session):
|
||||
"""Test get_multi_with_total handles errors gracefully."""
|
||||
# Test that it doesn't crash on invalid sort fields
|
||||
users, total = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
sort_by="nonexistent_field_xyz",
|
||||
sort_order="asc"
|
||||
)
|
||||
# Should still return results, just ignore invalid sort
|
||||
assert isinstance(users, list)
|
||||
assert isinstance(total, int)
|
||||
|
||||
def test_update_with_model_dict(self, db_session):
|
||||
"""Test update works with dict input."""
|
||||
user = User(
|
||||
email="updatedict2@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Original",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
# Update with plain dict
|
||||
update_data = {"first_name": "DictUpdated"}
|
||||
updated = user_crud.update(db_session, db_obj=user, obj_in=update_data)
|
||||
|
||||
assert updated.first_name == "DictUpdated"
|
||||
|
||||
def test_update_preserves_unchanged_fields(self, db_session):
|
||||
"""Test that update doesn't modify unspecified fields."""
|
||||
user = User(
|
||||
email="preserve@example.com",
|
||||
password_hash="original_hash",
|
||||
first_name="Original",
|
||||
last_name="Name",
|
||||
phone_number="+1234567890",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
original_password = user.password_hash
|
||||
original_phone = user.phone_number
|
||||
|
||||
# Only update first_name
|
||||
update_data = UserUpdate(first_name="Updated")
|
||||
updated = user_crud.update(db_session, db_obj=user, obj_in=update_data)
|
||||
|
||||
assert updated.first_name == "Updated"
|
||||
assert updated.password_hash == original_password # Unchanged
|
||||
assert updated.phone_number == original_phone # Unchanged
|
||||
assert updated.last_name == "Name" # Unchanged
|
||||
|
||||
|
||||
class TestCRUDValidation:
|
||||
"""Tests for validation in CRUD operations."""
|
||||
|
||||
def test_get_multi_with_empty_results(self, db_session):
|
||||
"""Test get_multi with no results."""
|
||||
# Query with filters that return no results
|
||||
users, total = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
filters={"email": "nonexistent@example.com"}
|
||||
)
|
||||
|
||||
assert users == []
|
||||
assert total == 0
|
||||
|
||||
def test_get_multi_with_large_offset(self, db_session):
|
||||
"""Test get_multi with offset larger than total records."""
|
||||
users = user_crud.get_multi(db_session, skip=10000, limit=10)
|
||||
assert users == []
|
||||
|
||||
def test_update_with_no_changes(self, db_session):
|
||||
"""Test update when no fields are changed."""
|
||||
user = User(
|
||||
email="nochanges@example.com",
|
||||
password_hash="hash",
|
||||
first_name="NoChanges",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
# Update with empty dict
|
||||
update_data = {}
|
||||
updated = user_crud.update(db_session, db_obj=user, obj_in=update_data)
|
||||
|
||||
# Should still return the user, unchanged
|
||||
assert updated.id == user.id
|
||||
assert updated.first_name == "NoChanges"
|
||||
@@ -7,7 +7,7 @@ from uuid import uuid4
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.organization_async import organization_async
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.models.user import User
|
||||
@@ -35,7 +35,7 @@ class TestGetBySlug:
|
||||
|
||||
# Get by slug
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.get_by_slug(session, slug="test-org")
|
||||
result = await organization_crud.get_by_slug(session, slug="test-org")
|
||||
assert result is not None
|
||||
assert result.id == org_id
|
||||
assert result.slug == "test-org"
|
||||
@@ -46,7 +46,7 @@ class TestGetBySlug:
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.get_by_slug(session, slug="nonexistent")
|
||||
result = await organization_crud.get_by_slug(session, slug="nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class TestCreate:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_success(self, async_test_db):
|
||||
"""Test successfully creating an organization."""
|
||||
"""Test successfully creating an organization_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -66,7 +66,7 @@ class TestCreate:
|
||||
is_active=True,
|
||||
settings={"key": "value"}
|
||||
)
|
||||
result = await organization_async.create(session, obj_in=org_in)
|
||||
result = await organization_crud.create(session, obj_in=org_in)
|
||||
|
||||
assert result.name == "New Org"
|
||||
assert result.slug == "new-org"
|
||||
@@ -92,7 +92,7 @@ class TestCreate:
|
||||
slug="duplicate-slug"
|
||||
)
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
await organization_async.create(session, obj_in=org_in)
|
||||
await organization_crud.create(session, obj_in=org_in)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_without_settings(self, async_test_db):
|
||||
@@ -104,7 +104,7 @@ class TestCreate:
|
||||
name="No Settings Org",
|
||||
slug="no-settings"
|
||||
)
|
||||
result = await organization_async.create(session, obj_in=org_in)
|
||||
result = await organization_crud.create(session, obj_in=org_in)
|
||||
|
||||
assert result.settings == {}
|
||||
|
||||
@@ -125,7 +125,7 @@ class TestGetMultiWithFilters:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(session)
|
||||
orgs, total = await organization_crud.get_multi_with_filters(session)
|
||||
assert total == 5
|
||||
assert len(orgs) == 5
|
||||
|
||||
@@ -141,7 +141,7 @@ class TestGetMultiWithFilters:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(
|
||||
orgs, total = await organization_crud.get_multi_with_filters(
|
||||
session,
|
||||
is_active=True
|
||||
)
|
||||
@@ -160,7 +160,7 @@ class TestGetMultiWithFilters:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(
|
||||
orgs, total = await organization_crud.get_multi_with_filters(
|
||||
session,
|
||||
search="tech"
|
||||
)
|
||||
@@ -179,7 +179,7 @@ class TestGetMultiWithFilters:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(
|
||||
orgs, total = await organization_crud.get_multi_with_filters(
|
||||
session,
|
||||
skip=2,
|
||||
limit=3
|
||||
@@ -199,7 +199,7 @@ class TestGetMultiWithFilters:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(
|
||||
orgs, total = await organization_crud.get_multi_with_filters(
|
||||
session,
|
||||
sort_by="name",
|
||||
sort_order="asc"
|
||||
@@ -213,7 +213,7 @@ class TestGetMemberCount:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_member_count_success(self, async_test_db, async_test_user):
|
||||
"""Test getting member count for organization."""
|
||||
"""Test getting member count for organization_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -233,7 +233,7 @@ class TestGetMemberCount:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await organization_async.get_member_count(session, organization_id=org_id)
|
||||
count = await organization_crud.get_member_count(session, organization_id=org_id)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -248,7 +248,7 @@ class TestGetMemberCount:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await organization_async.get_member_count(session, organization_id=org_id)
|
||||
count = await organization_crud.get_member_count(session, organization_id=org_id)
|
||||
assert count == 0
|
||||
|
||||
|
||||
@@ -257,7 +257,7 @@ class TestAddUser:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully adding a user to organization."""
|
||||
"""Test successfully adding a user to organization_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -267,7 +267,7 @@ class TestAddUser:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.add_user(
|
||||
result = await organization_crud.add_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id,
|
||||
@@ -301,7 +301,7 @@ class TestAddUser:
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="already a member"):
|
||||
await organization_async.add_user(
|
||||
await organization_crud.add_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id
|
||||
@@ -328,7 +328,7 @@ class TestAddUser:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.add_user(
|
||||
result = await organization_crud.add_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id,
|
||||
@@ -344,7 +344,7 @@ class TestRemoveUser:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully removing a user from organization."""
|
||||
"""Test successfully removing a user from organization_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -363,7 +363,7 @@ class TestRemoveUser:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.remove_user(
|
||||
result = await organization_crud.remove_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id
|
||||
@@ -393,7 +393,7 @@ class TestRemoveUser:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.remove_user(
|
||||
result = await organization_crud.remove_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=uuid4()
|
||||
@@ -426,7 +426,7 @@ class TestUpdateUserRole:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.update_user_role(
|
||||
result = await organization_crud.update_user_role(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id,
|
||||
@@ -449,7 +449,7 @@ class TestUpdateUserRole:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.update_user_role(
|
||||
result = await organization_crud.update_user_role(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=uuid4(),
|
||||
@@ -483,7 +483,7 @@ class TestGetOrganizationMembers:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
members, total = await organization_async.get_organization_members(
|
||||
members, total = await organization_crud.get_organization_members(
|
||||
session,
|
||||
organization_id=org_id
|
||||
)
|
||||
@@ -515,7 +515,7 @@ class TestGetOrganizationMembers:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
members, total = await organization_async.get_organization_members(
|
||||
members, total = await organization_crud.get_organization_members(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
skip=0,
|
||||
@@ -549,7 +549,7 @@ class TestGetUserOrganizations:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs = await organization_async.get_user_organizations(
|
||||
orgs = await organization_crud.get_user_organizations(
|
||||
session,
|
||||
user_id=async_test_user.id
|
||||
)
|
||||
@@ -584,7 +584,7 @@ class TestGetUserOrganizations:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs = await organization_async.get_user_organizations(
|
||||
orgs = await organization_crud.get_user_organizations(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
is_active=True
|
||||
@@ -599,7 +599,7 @@ class TestGetUserRole:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_role_in_org_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user role in organization."""
|
||||
"""Test getting user role in organization_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -618,7 +618,7 @@ class TestGetUserRole:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
role = await organization_async.get_user_role_in_org(
|
||||
role = await organization_crud.get_user_role_in_org(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
@@ -638,7 +638,7 @@ class TestGetUserRole:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
role = await organization_async.get_user_role_in_org(
|
||||
role = await organization_crud.get_user_role_in_org(
|
||||
session,
|
||||
user_id=uuid4(),
|
||||
organization_id=org_id
|
||||
@@ -671,7 +671,7 @@ class TestIsUserOrgOwner:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_owner = await organization_async.is_user_org_owner(
|
||||
is_owner = await organization_crud.is_user_org_owner(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
@@ -700,7 +700,7 @@ class TestIsUserOrgOwner:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_owner = await organization_async.is_user_org_owner(
|
||||
is_owner = await organization_crud.is_user_org_owner(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
@@ -734,7 +734,7 @@ class TestGetMultiWithMemberCounts:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_counts, total = await organization_async.get_multi_with_member_counts(session)
|
||||
orgs_with_counts, total = await organization_crud.get_multi_with_member_counts(session)
|
||||
|
||||
assert total == 2
|
||||
assert len(orgs_with_counts) == 2
|
||||
@@ -754,7 +754,7 @@ class TestGetMultiWithMemberCounts:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_counts, total = await organization_async.get_multi_with_member_counts(
|
||||
orgs_with_counts, total = await organization_crud.get_multi_with_member_counts(
|
||||
session,
|
||||
is_active=True
|
||||
)
|
||||
@@ -774,7 +774,7 @@ class TestGetMultiWithMemberCounts:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_counts, total = await organization_async.get_multi_with_member_counts(
|
||||
orgs_with_counts, total = await organization_crud.get_multi_with_member_counts(
|
||||
session,
|
||||
search="tech"
|
||||
)
|
||||
@@ -806,7 +806,7 @@ class TestGetUserOrganizationsWithDetails:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_details = await organization_async.get_user_organizations_with_details(
|
||||
orgs_with_details = await organization_crud.get_user_organizations_with_details(
|
||||
session,
|
||||
user_id=async_test_user.id
|
||||
)
|
||||
@@ -843,7 +843,7 @@ class TestGetUserOrganizationsWithDetails:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_details = await organization_async.get_user_organizations_with_details(
|
||||
orgs_with_details = await organization_crud.get_user_organizations_with_details(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
is_active=True
|
||||
@@ -877,7 +877,7 @@ class TestIsUserOrgAdmin:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_admin = await organization_async.is_user_org_admin(
|
||||
is_admin = await organization_crud.is_user_org_admin(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
@@ -906,7 +906,7 @@ class TestIsUserOrgAdmin:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_admin = await organization_async.is_user_org_admin(
|
||||
is_admin = await organization_crud.is_user_org_admin(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
@@ -935,7 +935,7 @@ class TestIsUserOrgAdmin:
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_admin = await organization_async.is_user_org_admin(
|
||||
is_admin = await organization_crud.is_user_org_admin(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud.session_async import session_async
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
@@ -34,7 +34,7 @@ class TestGetByJti:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.get_by_jti(session, jti="test_jti_123")
|
||||
result = await session_crud.get_by_jti(session, jti="test_jti_123")
|
||||
assert result is not None
|
||||
assert result.refresh_token_jti == "test_jti_123"
|
||||
|
||||
@@ -44,7 +44,7 @@ class TestGetByJti:
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.get_by_jti(session, jti="nonexistent")
|
||||
result = await session_crud.get_by_jti(session, jti="nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ class TestGetActiveByJti:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.get_active_by_jti(session, jti="active_jti")
|
||||
result = await session_crud.get_active_by_jti(session, jti="active_jti")
|
||||
assert result is not None
|
||||
assert result.is_active is True
|
||||
|
||||
@@ -95,7 +95,7 @@ class TestGetActiveByJti:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.get_active_by_jti(session, jti="inactive_jti")
|
||||
result = await session_crud.get_active_by_jti(session, jti="inactive_jti")
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ class TestGetUserSessions:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_async.get_user_sessions(
|
||||
results = await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
active_only=True
|
||||
@@ -161,7 +161,7 @@ class TestGetUserSessions:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_async.get_user_sessions(
|
||||
results = await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
active_only=False
|
||||
@@ -174,7 +174,7 @@ class TestCreateSession:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully creating a session."""
|
||||
"""Test successfully creating a session_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -190,7 +190,7 @@ class TestCreateSession:
|
||||
location_city="San Francisco",
|
||||
location_country="USA"
|
||||
)
|
||||
result = await session_async.create_session(session, obj_in=session_data)
|
||||
result = await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
assert result.user_id == async_test_user.id
|
||||
assert result.refresh_token_jti == "new_jti"
|
||||
@@ -203,7 +203,7 @@ class TestDeactivate:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully deactivating a session."""
|
||||
"""Test successfully deactivating a session_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -222,7 +222,7 @@ class TestDeactivate:
|
||||
session_id = user_session.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.deactivate(session, session_id=str(session_id))
|
||||
result = await session_crud.deactivate(session, session_id=str(session_id))
|
||||
assert result is not None
|
||||
assert result.is_active is False
|
||||
|
||||
@@ -232,7 +232,7 @@ class TestDeactivate:
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.deactivate(session, session_id=str(uuid4()))
|
||||
result = await session_crud.deactivate(session, session_id=str(uuid4()))
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -260,7 +260,7 @@ class TestDeactivateAllUserSessions:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_async.deactivate_all_user_sessions(
|
||||
count = await session_crud.deactivate_all_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
@@ -291,7 +291,7 @@ class TestUpdateLastUsed:
|
||||
await session.refresh(user_session)
|
||||
|
||||
old_time = user_session.last_used_at
|
||||
result = await session_async.update_last_used(session, session=user_session)
|
||||
result = await session_crud.update_last_used(session, session=user_session)
|
||||
|
||||
assert result.last_used_at > old_time
|
||||
|
||||
@@ -320,7 +320,7 @@ class TestGetUserSessionCount:
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_async.get_user_session_count(
|
||||
count = await session_crud.get_user_session_count(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
@@ -332,7 +332,7 @@ class TestGetUserSessionCount:
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_async.get_user_session_count(
|
||||
count = await session_crud.get_user_session_count(
|
||||
session,
|
||||
user_id=str(uuid4())
|
||||
)
|
||||
@@ -1,324 +0,0 @@
|
||||
# tests/crud/test_soft_delete.py
|
||||
"""
|
||||
Tests for soft delete functionality in CRUD operations.
|
||||
"""
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.models.user import User
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
|
||||
class TestSoftDelete:
|
||||
"""Tests for soft delete functionality."""
|
||||
|
||||
def test_soft_delete_marks_deleted_at(self, db_session):
|
||||
"""Test that soft delete sets deleted_at timestamp."""
|
||||
# Create a user
|
||||
test_user = User(
|
||||
email="softdelete@example.com",
|
||||
password_hash="hashedpassword",
|
||||
first_name="Soft",
|
||||
last_name="Delete",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(test_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(test_user)
|
||||
|
||||
user_id = test_user.id
|
||||
assert test_user.deleted_at is None
|
||||
|
||||
# Soft delete the user
|
||||
deleted_user = user_crud.soft_delete(db_session, id=user_id)
|
||||
|
||||
assert deleted_user is not None
|
||||
assert deleted_user.deleted_at is not None
|
||||
assert isinstance(deleted_user.deleted_at, datetime)
|
||||
|
||||
def test_soft_delete_excludes_from_get_multi(self, db_session):
|
||||
"""Test that soft deleted records are excluded from get_multi."""
|
||||
# Create two users
|
||||
user1 = User(
|
||||
email="user1@example.com",
|
||||
password_hash="hash1",
|
||||
first_name="User",
|
||||
last_name="One",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
user2 = User(
|
||||
email="user2@example.com",
|
||||
password_hash="hash2",
|
||||
first_name="User",
|
||||
last_name="Two",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add_all([user1, user2])
|
||||
db_session.commit()
|
||||
db_session.refresh(user1)
|
||||
db_session.refresh(user2)
|
||||
|
||||
# Both users should be returned
|
||||
users, total = user_crud.get_multi_with_total(db_session)
|
||||
assert total >= 2
|
||||
user_emails = [u.email for u in users]
|
||||
assert "user1@example.com" in user_emails
|
||||
assert "user2@example.com" in user_emails
|
||||
|
||||
# Soft delete user1
|
||||
user_crud.soft_delete(db_session, id=user1.id)
|
||||
|
||||
# Only user2 should be returned
|
||||
users, total = user_crud.get_multi_with_total(db_session)
|
||||
user_emails = [u.email for u in users]
|
||||
assert "user1@example.com" not in user_emails
|
||||
assert "user2@example.com" in user_emails
|
||||
|
||||
def test_soft_delete_still_retrievable_by_get(self, db_session):
|
||||
"""Test that soft deleted records can still be retrieved by get() method."""
|
||||
# Create a user
|
||||
user = User(
|
||||
email="gettest@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Get",
|
||||
last_name="Test",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
user_id = user.id
|
||||
|
||||
# User should be retrievable
|
||||
retrieved = user_crud.get(db_session, id=user_id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.email == "gettest@example.com"
|
||||
assert retrieved.deleted_at is None
|
||||
|
||||
# Soft delete the user
|
||||
user_crud.soft_delete(db_session, id=user_id)
|
||||
|
||||
# User should still be retrievable by ID (soft delete doesn't prevent direct access)
|
||||
retrieved = user_crud.get(db_session, id=user_id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.deleted_at is not None
|
||||
|
||||
def test_soft_delete_nonexistent_record(self, db_session):
|
||||
"""Test soft deleting a record that doesn't exist."""
|
||||
import uuid
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
result = user_crud.soft_delete(db_session, id=fake_id)
|
||||
assert result is None
|
||||
|
||||
def test_restore_sets_deleted_at_to_none(self, db_session):
|
||||
"""Test that restore clears the deleted_at timestamp."""
|
||||
# Create and soft delete a user
|
||||
user = User(
|
||||
email="restore@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Restore",
|
||||
last_name="Test",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
user_id = user.id
|
||||
|
||||
# Soft delete
|
||||
user_crud.soft_delete(db_session, id=user_id)
|
||||
db_session.refresh(user)
|
||||
assert user.deleted_at is not None
|
||||
|
||||
# Restore
|
||||
restored_user = user_crud.restore(db_session, id=user_id)
|
||||
|
||||
assert restored_user is not None
|
||||
assert restored_user.deleted_at is None
|
||||
|
||||
def test_restore_makes_record_available(self, db_session):
|
||||
"""Test that restored records appear in queries."""
|
||||
# Create and soft delete a user
|
||||
user = User(
|
||||
email="available@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Available",
|
||||
last_name="Test",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
user_id = user.id
|
||||
user_email = user.email
|
||||
|
||||
# Soft delete
|
||||
user_crud.soft_delete(db_session, id=user_id)
|
||||
|
||||
# User should not be in query results
|
||||
users, _ = user_crud.get_multi_with_total(db_session)
|
||||
emails = [u.email for u in users]
|
||||
assert user_email not in emails
|
||||
|
||||
# Restore
|
||||
user_crud.restore(db_session, id=user_id)
|
||||
|
||||
# User should now be in query results
|
||||
users, _ = user_crud.get_multi_with_total(db_session)
|
||||
emails = [u.email for u in users]
|
||||
assert user_email in emails
|
||||
|
||||
def test_restore_nonexistent_record(self, db_session):
|
||||
"""Test restoring a record that doesn't exist."""
|
||||
import uuid
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
result = user_crud.restore(db_session, id=fake_id)
|
||||
assert result is None
|
||||
|
||||
def test_restore_already_active_record(self, db_session):
|
||||
"""Test restoring a record that was never deleted returns None."""
|
||||
# Create a user (not deleted)
|
||||
user = User(
|
||||
email="never_deleted@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Never",
|
||||
last_name="Deleted",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
user_id = user.id
|
||||
assert user.deleted_at is None
|
||||
|
||||
# Restore should return None (record is not soft-deleted)
|
||||
restored = user_crud.restore(db_session, id=user_id)
|
||||
assert restored is None
|
||||
|
||||
def test_soft_delete_multiple_times(self, db_session):
|
||||
"""Test soft deleting the same record multiple times."""
|
||||
# Create a user
|
||||
user = User(
|
||||
email="multiple_delete@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Multiple",
|
||||
last_name="Delete",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
user_id = user.id
|
||||
|
||||
# First soft delete
|
||||
first_deleted = user_crud.soft_delete(db_session, id=user_id)
|
||||
assert first_deleted is not None
|
||||
first_timestamp = first_deleted.deleted_at
|
||||
|
||||
# Restore
|
||||
user_crud.restore(db_session, id=user_id)
|
||||
|
||||
# Second soft delete
|
||||
second_deleted = user_crud.soft_delete(db_session, id=user_id)
|
||||
assert second_deleted is not None
|
||||
second_timestamp = second_deleted.deleted_at
|
||||
|
||||
# Timestamps should be different
|
||||
assert second_timestamp != first_timestamp
|
||||
assert second_timestamp > first_timestamp
|
||||
|
||||
def test_get_multi_with_filters_excludes_deleted(self, db_session):
|
||||
"""Test that get_multi_with_total with filters excludes deleted records."""
|
||||
# Create active and inactive users
|
||||
active_user = User(
|
||||
email="active_not_deleted@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Active",
|
||||
last_name="NotDeleted",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
inactive_user = User(
|
||||
email="inactive_not_deleted@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Inactive",
|
||||
last_name="NotDeleted",
|
||||
is_active=False,
|
||||
is_superuser=False
|
||||
)
|
||||
deleted_active_user = User(
|
||||
email="active_deleted@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Active",
|
||||
last_name="Deleted",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
|
||||
db_session.add_all([active_user, inactive_user, deleted_active_user])
|
||||
db_session.commit()
|
||||
db_session.refresh(deleted_active_user)
|
||||
|
||||
# Soft delete one active user
|
||||
user_crud.soft_delete(db_session, id=deleted_active_user.id)
|
||||
|
||||
# Filter for active users - should only return non-deleted active user
|
||||
users, total = user_crud.get_multi_with_total(
|
||||
db_session,
|
||||
filters={"is_active": True}
|
||||
)
|
||||
|
||||
emails = [u.email for u in users]
|
||||
assert "active_not_deleted@example.com" in emails
|
||||
assert "active_deleted@example.com" not in emails
|
||||
assert "inactive_not_deleted@example.com" not in emails
|
||||
|
||||
def test_soft_delete_preserves_other_fields(self, db_session):
|
||||
"""Test that soft delete doesn't modify other fields."""
|
||||
# Create a user with specific data
|
||||
user = User(
|
||||
email="preserve@example.com",
|
||||
password_hash="original_hash",
|
||||
first_name="Preserve",
|
||||
last_name="Fields",
|
||||
phone_number="+1234567890",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
preferences={"theme": "dark"}
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
|
||||
user_id = user.id
|
||||
original_email = user.email
|
||||
original_hash = user.password_hash
|
||||
original_first_name = user.first_name
|
||||
original_phone = user.phone_number
|
||||
original_preferences = user.preferences
|
||||
|
||||
# Soft delete
|
||||
deleted = user_crud.soft_delete(db_session, id=user_id)
|
||||
|
||||
# All other fields should remain unchanged
|
||||
assert deleted.email == original_email
|
||||
assert deleted.password_hash == original_hash
|
||||
assert deleted.first_name == original_first_name
|
||||
assert deleted.phone_number == original_phone
|
||||
assert deleted.preferences == original_preferences
|
||||
assert deleted.is_active is True # is_active unchanged
|
||||
703
backend/tests/crud/test_user.py
Executable file → Normal file
703
backend/tests/crud/test_user.py
Executable file → Normal file
@@ -1,125 +1,644 @@
|
||||
# tests/crud/test_user_async.py
|
||||
"""
|
||||
Comprehensive tests for async user CRUD operations.
|
||||
"""
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
def test_create_user(db_session, user_create_data):
|
||||
user_in = UserCreate(**user_create_data)
|
||||
user_obj = user_crud.create(db_session, obj_in=user_in)
|
||||
class TestGetByEmail:
|
||||
"""Tests for get_by_email method."""
|
||||
|
||||
assert user_obj.email == user_create_data["email"]
|
||||
assert user_obj.first_name == user_create_data["first_name"]
|
||||
assert user_obj.last_name == user_create_data["last_name"]
|
||||
assert user_obj.phone_number == user_create_data["phone_number"]
|
||||
assert user_obj.is_superuser == user_create_data["is_superuser"]
|
||||
assert user_obj.password_hash is not None
|
||||
assert user_obj.id is not None
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user by email."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_crud.get_by_email(session, email=async_test_user.email)
|
||||
assert result is not None
|
||||
assert result.email == async_test_user.email
|
||||
assert result.id == async_test_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_not_found(self, async_test_db):
|
||||
"""Test getting non-existent email returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_crud.get_by_email(session, email="nonexistent@example.com")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_user(db_session, mock_user):
|
||||
# Using mock_user fixture instead of creating new user
|
||||
stored_user = user_crud.get(db_session, id=mock_user.id)
|
||||
assert stored_user
|
||||
assert stored_user.id == mock_user.id
|
||||
assert stored_user.email == mock_user.email
|
||||
class TestCreate:
|
||||
"""Tests for create method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, async_test_db):
|
||||
"""Test successfully creating a user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="newuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="+1234567890"
|
||||
)
|
||||
result = await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
assert result.email == "newuser@example.com"
|
||||
assert result.first_name == "New"
|
||||
assert result.last_name == "User"
|
||||
assert result.phone_number == "+1234567890"
|
||||
assert result.is_active is True
|
||||
assert result.is_superuser is False
|
||||
assert result.password_hash is not None
|
||||
assert result.password_hash != "SecurePass123!" # Password should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_superuser_success(self, async_test_db):
|
||||
"""Test creating a superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="superuser@example.com",
|
||||
password="SuperPass123!",
|
||||
first_name="Super",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
)
|
||||
result = await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
assert result.is_superuser is True
|
||||
assert result.email == "superuser@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
|
||||
"""Test creating user with duplicate email raises ValueError."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Duplicate email
|
||||
password="AnotherPass123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
assert "already exists" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
def test_get_user_by_email(db_session, mock_user):
|
||||
stored_user = user_crud.get_by_email(db_session, email=mock_user.email)
|
||||
assert stored_user
|
||||
assert stored_user.id == mock_user.id
|
||||
assert stored_user.email == mock_user.email
|
||||
class TestUpdate:
|
||||
"""Tests for update method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
|
||||
"""Test updating basic user fields."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get fresh copy of user
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_data = UserUpdate(
|
||||
first_name="Updated",
|
||||
last_name="Name",
|
||||
phone_number="+9876543210"
|
||||
)
|
||||
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
assert result.first_name == "Updated"
|
||||
assert result.last_name == "Name"
|
||||
assert result.phone_number == "+9876543210"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_password(self, async_test_db):
|
||||
"""Test updating user password."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a fresh user for this test
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="passwordtest@example.com",
|
||||
password="OldPassword123!",
|
||||
first_name="Pass",
|
||||
last_name="Test"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
old_password_hash = user.password_hash
|
||||
|
||||
# Update the password
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(user_id))
|
||||
|
||||
update_data = UserUpdate(password="NewDifferentPassword123!")
|
||||
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
await session.refresh(result)
|
||||
assert result.password_hash != old_password_hash
|
||||
assert result.password_hash is not None
|
||||
assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test updating user with dictionary."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_dict = {"first_name": "DictUpdate"}
|
||||
result = await user_crud.update(session, db_obj=user, obj_in=update_dict)
|
||||
|
||||
assert result.first_name == "DictUpdate"
|
||||
|
||||
|
||||
def test_update_user(db_session, mock_user):
|
||||
update_data = UserUpdate(
|
||||
first_name="Updated",
|
||||
last_name="Name",
|
||||
phone_number="+9876543210"
|
||||
)
|
||||
class TestGetMultiWithTotal:
|
||||
"""Tests for get_multi_with_total method."""
|
||||
|
||||
updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
|
||||
"""Test basic pagination."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
assert updated_user.first_name == "Updated"
|
||||
assert updated_user.last_name == "Name"
|
||||
assert updated_user.phone_number == "+9876543210"
|
||||
assert updated_user.email == mock_user.email
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10
|
||||
)
|
||||
assert total >= 1
|
||||
assert len(users) >= 1
|
||||
assert any(u.id == async_test_user.id for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
|
||||
"""Test sorting in ascending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"sort{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="asc"
|
||||
)
|
||||
|
||||
# Check if sorted (at least the test users)
|
||||
test_users = [u for u in users if u.email.startswith("sort")]
|
||||
if len(test_users) > 1:
|
||||
assert test_users[0].email < test_users[1].email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
|
||||
"""Test sorting in descending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"desc{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="desc"
|
||||
)
|
||||
|
||||
# Check if sorted descending (at least the test users)
|
||||
test_users = [u for u in users if u.email.startswith("desc")]
|
||||
if len(test_users) > 1:
|
||||
assert test_users[0].email > test_users[1].email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_filtering(self, async_test_db):
|
||||
"""Test filtering by field."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active and inactive users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_user = UserCreate(
|
||||
email="active@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Active",
|
||||
last_name="User"
|
||||
)
|
||||
await user_crud.create(session, obj_in=active_user)
|
||||
|
||||
inactive_user = UserCreate(
|
||||
email="inactive@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
)
|
||||
created_inactive = await user_crud.create(session, obj_in=inactive_user)
|
||||
|
||||
# Deactivate the user
|
||||
await user_crud.update(
|
||||
session,
|
||||
db_obj=created_inactive,
|
||||
obj_in={"is_active": False}
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
filters={"is_active": True}
|
||||
)
|
||||
|
||||
# All returned users should be active
|
||||
assert all(u.is_active for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_search(self, async_test_db):
|
||||
"""Test search functionality."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user with unique name
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="searchable@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Searchable",
|
||||
last_name="UserName"
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
search="Searchable"
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert any(u.first_name == "Searchable" for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_pagination(self, async_test_db):
|
||||
"""Test pagination with skip and limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(5):
|
||||
user_data = UserCreate(
|
||||
email=f"page{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Page{i}",
|
||||
last_name="User"
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get first page
|
||||
users_page1, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Get second page
|
||||
users_page2, total2 = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=2,
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Total should be same
|
||||
assert total == total2
|
||||
# Different users on different pages
|
||||
assert users_page1[0].id != users_page2[0].id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
|
||||
"""Test validation fails for negative skip."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
||||
|
||||
assert "skip must be non-negative" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
|
||||
"""Test validation fails for negative limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
||||
|
||||
assert "limit must be non-negative" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
|
||||
"""Test validation fails for limit > 1000."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
assert "Maximum limit is 1000" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_delete_user(db_session, mock_user):
|
||||
user_crud.remove(db_session, id=mock_user.id)
|
||||
deleted_user = user_crud.get(db_session, id=mock_user.id)
|
||||
assert deleted_user is None
|
||||
class TestBulkUpdateStatus:
|
||||
"""Tests for bulk_update_status method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_success(self, async_test_db):
|
||||
"""Test bulk updating user status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"bulk{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Bulk{i}",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk deactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
is_active=False
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are inactive
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for user_id in user_ids:
|
||||
user = await user_crud.get(session, id=str(user_id))
|
||||
assert user.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_empty_list(self, async_test_db):
|
||||
"""Test bulk update with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=[],
|
||||
is_active=False
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_reactivate(self, async_test_db):
|
||||
"""Test bulk reactivating users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="reactivate@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Reactivate",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
# Deactivate
|
||||
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
user_id = user.id
|
||||
|
||||
# Reactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
is_active=True
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
# Verify active
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(user_id))
|
||||
assert user.is_active is True
|
||||
|
||||
|
||||
def test_get_multi_users(db_session, mock_user, user_create_data):
|
||||
# Create additional users (mock_user is already in db)
|
||||
users_data = [
|
||||
{**user_create_data, "email": f"test{i}@example.com"}
|
||||
for i in range(2) # Creating 2 more users + mock_user = 3 total
|
||||
]
|
||||
class TestBulkSoftDelete:
|
||||
"""Tests for bulk_soft_delete method."""
|
||||
|
||||
for user_data in users_data:
|
||||
user_in = UserCreate(**user_data)
|
||||
user_crud.create(db_session, obj_in=user_in)
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_success(self, async_test_db):
|
||||
"""Test bulk soft deleting users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
users = user_crud.get_multi(db_session, skip=0, limit=10)
|
||||
assert len(users) == 3
|
||||
assert all(isinstance(user, User) for user in users)
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"delete{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Delete{i}",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are soft deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for user_id in user_ids:
|
||||
user = await user_crud.get(session, id=str(user_id))
|
||||
assert user.deleted_at is not None
|
||||
assert user.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
|
||||
"""Test bulk soft delete with excluded user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"exclude{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Exclude{i}",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete, excluding first user
|
||||
exclude_id = user_ids[0]
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
exclude_user_id=exclude_id
|
||||
)
|
||||
assert count == 2 # Only 2 deleted
|
||||
|
||||
# Verify excluded user is NOT deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
excluded_user = await user_crud.get(session, id=str(exclude_id))
|
||||
assert excluded_user.deleted_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_empty_list(self, async_test_db):
|
||||
"""Test bulk delete with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[]
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
|
||||
"""Test bulk delete where all users are excluded."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="onlyuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Only",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
|
||||
# Try to delete but exclude
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
exclude_user_id=user_id
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
|
||||
"""Test bulk delete doesn't re-delete already deleted users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create and delete user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="predeleted@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="PreDeleted",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
|
||||
# First deletion
|
||||
await user_crud.bulk_soft_delete(session, user_ids=[user_id])
|
||||
|
||||
# Try to delete again
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id]
|
||||
)
|
||||
assert count == 0 # Already deleted
|
||||
|
||||
|
||||
def test_is_active(db_session, mock_user):
|
||||
assert user_crud.is_active(mock_user) is True
|
||||
class TestUtilityMethods:
|
||||
"""Tests for utility methods."""
|
||||
|
||||
# Test deactivating user
|
||||
update_data = UserUpdate(is_active=False)
|
||||
deactivated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
|
||||
assert user_crud.is_active(deactivated_user) is False
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_true(self, async_test_db, async_test_user):
|
||||
"""Test is_active returns True for active user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
assert user_crud.is_active(user) is True
|
||||
|
||||
def test_is_superuser(db_session, mock_user, user_create_data):
|
||||
# mock_user is regular user
|
||||
assert user_crud.is_superuser(mock_user) is False
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_false(self, async_test_db):
|
||||
"""Test is_active returns False for inactive user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create superuser
|
||||
super_user_data = {**user_create_data, "email": "super@example.com", "is_superuser": True}
|
||||
super_user_in = UserCreate(**super_user_data)
|
||||
super_user = user_crud.create(db_session, obj_in=super_user_in)
|
||||
assert user_crud.is_superuser(super_user) is True
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="inactive2@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
|
||||
assert user_crud.is_active(user) is False
|
||||
|
||||
# Additional test cases
|
||||
def test_create_duplicate_email(db_session, mock_user):
|
||||
user_data = UserCreate(
|
||||
email=mock_user.email, # Try to create user with existing email
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
with pytest.raises(Exception): # Should raise an integrity error
|
||||
user_crud.create(db_session, obj_in=user_data)
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
|
||||
"""Test is_superuser returns True for superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_superuser.id))
|
||||
assert user_crud.is_superuser(user) is True
|
||||
|
||||
def test_update_user_preferences(db_session, mock_user):
|
||||
preferences = {"theme": "dark", "notifications": True}
|
||||
update_data = UserUpdate(preferences=preferences)
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_false(self, async_test_db, async_test_user):
|
||||
"""Test is_superuser returns False for regular user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
updated_user = user_crud.update(db_session, db_obj=mock_user, obj_in=update_data)
|
||||
assert updated_user.preferences == preferences
|
||||
|
||||
|
||||
def test_get_multi_users_pagination(db_session, user_create_data):
|
||||
# Create 5 users
|
||||
for i in range(5):
|
||||
user_in = UserCreate(**{**user_create_data, "email": f"test{i}@example.com"})
|
||||
user_crud.create(db_session, obj_in=user_in)
|
||||
|
||||
# Test pagination
|
||||
first_page = user_crud.get_multi(db_session, skip=0, limit=2)
|
||||
second_page = user_crud.get_multi(db_session, skip=2, limit=2)
|
||||
|
||||
assert len(first_page) == 2
|
||||
assert len(second_page) == 2
|
||||
assert first_page[0].id != second_page[0].id
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
assert user_crud.is_superuser(user) is False
|
||||
|
||||
@@ -1,644 +0,0 @@
|
||||
# tests/crud/test_user_async.py
|
||||
"""
|
||||
Comprehensive tests for async user CRUD operations.
|
||||
"""
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud.user_async import user_async
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class TestGetByEmail:
|
||||
"""Tests for get_by_email method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user by email."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_async.get_by_email(session, email=async_test_user.email)
|
||||
assert result is not None
|
||||
assert result.email == async_test_user.email
|
||||
assert result.id == async_test_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_not_found(self, async_test_db):
|
||||
"""Test getting non-existent email returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_async.get_by_email(session, email="nonexistent@example.com")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreate:
|
||||
"""Tests for create method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, async_test_db):
|
||||
"""Test successfully creating a user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="newuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="+1234567890"
|
||||
)
|
||||
result = await user_async.create(session, obj_in=user_data)
|
||||
|
||||
assert result.email == "newuser@example.com"
|
||||
assert result.first_name == "New"
|
||||
assert result.last_name == "User"
|
||||
assert result.phone_number == "+1234567890"
|
||||
assert result.is_active is True
|
||||
assert result.is_superuser is False
|
||||
assert result.password_hash is not None
|
||||
assert result.password_hash != "SecurePass123!" # Password should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_superuser_success(self, async_test_db):
|
||||
"""Test creating a superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="superuser@example.com",
|
||||
password="SuperPass123!",
|
||||
first_name="Super",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
)
|
||||
result = await user_async.create(session, obj_in=user_data)
|
||||
|
||||
assert result.is_superuser is True
|
||||
assert result.email == "superuser@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
|
||||
"""Test creating user with duplicate email raises ValueError."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Duplicate email
|
||||
password="AnotherPass123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
assert "already exists" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
"""Tests for update method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
|
||||
"""Test updating basic user fields."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get fresh copy of user
|
||||
user = await user_async.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_data = UserUpdate(
|
||||
first_name="Updated",
|
||||
last_name="Name",
|
||||
phone_number="+9876543210"
|
||||
)
|
||||
result = await user_async.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
assert result.first_name == "Updated"
|
||||
assert result.last_name == "Name"
|
||||
assert result.phone_number == "+9876543210"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_password(self, async_test_db):
|
||||
"""Test updating user password."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a fresh user for this test
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="passwordtest@example.com",
|
||||
password="OldPassword123!",
|
||||
first_name="Pass",
|
||||
last_name="Test"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
old_password_hash = user.password_hash
|
||||
|
||||
# Update the password
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(user_id))
|
||||
|
||||
update_data = UserUpdate(password="NewDifferentPassword123!")
|
||||
result = await user_async.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
await session.refresh(result)
|
||||
assert result.password_hash != old_password_hash
|
||||
assert result.password_hash is not None
|
||||
assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test updating user with dictionary."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_dict = {"first_name": "DictUpdate"}
|
||||
result = await user_async.update(session, db_obj=user, obj_in=update_dict)
|
||||
|
||||
assert result.first_name == "DictUpdate"
|
||||
|
||||
|
||||
class TestGetMultiWithTotal:
|
||||
"""Tests for get_multi_with_total method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
|
||||
"""Test basic pagination."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10
|
||||
)
|
||||
assert total >= 1
|
||||
assert len(users) >= 1
|
||||
assert any(u.id == async_test_user.id for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
|
||||
"""Test sorting in ascending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"sort{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
)
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="asc"
|
||||
)
|
||||
|
||||
# Check if sorted (at least the test users)
|
||||
test_users = [u for u in users if u.email.startswith("sort")]
|
||||
if len(test_users) > 1:
|
||||
assert test_users[0].email < test_users[1].email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
|
||||
"""Test sorting in descending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"desc{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
)
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="desc"
|
||||
)
|
||||
|
||||
# Check if sorted descending (at least the test users)
|
||||
test_users = [u for u in users if u.email.startswith("desc")]
|
||||
if len(test_users) > 1:
|
||||
assert test_users[0].email > test_users[1].email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_filtering(self, async_test_db):
|
||||
"""Test filtering by field."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active and inactive users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_user = UserCreate(
|
||||
email="active@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Active",
|
||||
last_name="User"
|
||||
)
|
||||
await user_async.create(session, obj_in=active_user)
|
||||
|
||||
inactive_user = UserCreate(
|
||||
email="inactive@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
)
|
||||
created_inactive = await user_async.create(session, obj_in=inactive_user)
|
||||
|
||||
# Deactivate the user
|
||||
await user_async.update(
|
||||
session,
|
||||
db_obj=created_inactive,
|
||||
obj_in={"is_active": False}
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
filters={"is_active": True}
|
||||
)
|
||||
|
||||
# All returned users should be active
|
||||
assert all(u.is_active for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_search(self, async_test_db):
|
||||
"""Test search functionality."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user with unique name
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="searchable@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Searchable",
|
||||
last_name="UserName"
|
||||
)
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
search="Searchable"
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert any(u.first_name == "Searchable" for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_pagination(self, async_test_db):
|
||||
"""Test pagination with skip and limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(5):
|
||||
user_data = UserCreate(
|
||||
email=f"page{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Page{i}",
|
||||
last_name="User"
|
||||
)
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get first page
|
||||
users_page1, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Get second page
|
||||
users_page2, total2 = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=2,
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Total should be same
|
||||
assert total == total2
|
||||
# Different users on different pages
|
||||
assert users_page1[0].id != users_page2[0].id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
|
||||
"""Test validation fails for negative skip."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_async.get_multi_with_total(session, skip=-1, limit=10)
|
||||
|
||||
assert "skip must be non-negative" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
|
||||
"""Test validation fails for negative limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_async.get_multi_with_total(session, skip=0, limit=-1)
|
||||
|
||||
assert "limit must be non-negative" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
|
||||
"""Test validation fails for limit > 1000."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_async.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
assert "Maximum limit is 1000" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestBulkUpdateStatus:
|
||||
"""Tests for bulk_update_status method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_success(self, async_test_db):
|
||||
"""Test bulk updating user status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"bulk{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Bulk{i}",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk deactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_update_status(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
is_active=False
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are inactive
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for user_id in user_ids:
|
||||
user = await user_async.get(session, id=str(user_id))
|
||||
assert user.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_empty_list(self, async_test_db):
|
||||
"""Test bulk update with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_update_status(
|
||||
session,
|
||||
user_ids=[],
|
||||
is_active=False
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_reactivate(self, async_test_db):
|
||||
"""Test bulk reactivating users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="reactivate@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Reactivate",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
# Deactivate
|
||||
await user_async.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
user_id = user.id
|
||||
|
||||
# Reactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_update_status(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
is_active=True
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
# Verify active
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(user_id))
|
||||
assert user.is_active is True
|
||||
|
||||
|
||||
class TestBulkSoftDelete:
|
||||
"""Tests for bulk_soft_delete method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_success(self, async_test_db):
|
||||
"""Test bulk soft deleting users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"delete{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Delete{i}",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are soft deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for user_id in user_ids:
|
||||
user = await user_async.get(session, id=str(user_id))
|
||||
assert user.deleted_at is not None
|
||||
assert user.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
|
||||
"""Test bulk soft delete with excluded user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"exclude{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Exclude{i}",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete, excluding first user
|
||||
exclude_id = user_ids[0]
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
exclude_user_id=exclude_id
|
||||
)
|
||||
assert count == 2 # Only 2 deleted
|
||||
|
||||
# Verify excluded user is NOT deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
excluded_user = await user_async.get(session, id=str(exclude_id))
|
||||
assert excluded_user.deleted_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_empty_list(self, async_test_db):
|
||||
"""Test bulk delete with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[]
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
|
||||
"""Test bulk delete where all users are excluded."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="onlyuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Only",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
|
||||
# Try to delete but exclude
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
exclude_user_id=user_id
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
|
||||
"""Test bulk delete doesn't re-delete already deleted users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create and delete user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="predeleted@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="PreDeleted",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
|
||||
# First deletion
|
||||
await user_async.bulk_soft_delete(session, user_ids=[user_id])
|
||||
|
||||
# Try to delete again
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id]
|
||||
)
|
||||
assert count == 0 # Already deleted
|
||||
|
||||
|
||||
class TestUtilityMethods:
|
||||
"""Tests for utility methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_true(self, async_test_db, async_test_user):
|
||||
"""Test is_active returns True for active user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(async_test_user.id))
|
||||
assert user_async.is_active(user) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_false(self, async_test_db):
|
||||
"""Test is_active returns False for inactive user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="inactive2@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
await user_async.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
|
||||
assert user_async.is_active(user) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
|
||||
"""Test is_superuser returns True for superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(async_test_superuser.id))
|
||||
assert user_async.is_superuser(user) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_false(self, async_test_db, async_test_user):
|
||||
"""Test is_superuser returns False for regular user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(async_test_user.id))
|
||||
assert user_async.is_superuser(user) is False
|
||||
@@ -64,8 +64,8 @@ class TestCleanupExpiredSessions:
|
||||
session.add_all([active_session, old_expired_session, recent_expired_session])
|
||||
await session.commit()
|
||||
|
||||
# Mock AsyncSessionLocal to return our test session
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
# Mock SessionLocal to return our test session
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
@@ -102,7 +102,7 @@ class TestCleanupExpiredSessions:
|
||||
session.add(active)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
@@ -113,7 +113,7 @@ class TestCleanupExpiredSessions:
|
||||
"""Test cleanup with no sessions in database."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
@@ -139,7 +139,7 @@ class TestCleanupExpiredSessions:
|
||||
session.add(today_expired)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=0)
|
||||
|
||||
@@ -169,7 +169,7 @@ class TestCleanupExpiredSessions:
|
||||
session.add_all(sessions_to_add)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
@@ -181,7 +181,7 @@ class TestCleanupExpiredSessions:
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Mock session_crud.cleanup_expired to raise error
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup:
|
||||
mock_cleanup.side_effect = Exception("Database connection lost")
|
||||
|
||||
@@ -247,7 +247,7 @@ class TestGetSessionStatistics:
|
||||
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
stats = await get_session_statistics()
|
||||
|
||||
@@ -261,7 +261,7 @@ class TestGetSessionStatistics:
|
||||
"""Test getting statistics with no sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
stats = await get_session_statistics()
|
||||
|
||||
@@ -283,7 +283,7 @@ class TestGetSessionStatistics:
|
||||
async def mock_session_local():
|
||||
yield mock_session
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=mock_session_local()):
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=mock_session_local()):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
stats = await get_session_statistics()
|
||||
|
||||
@@ -317,7 +317,7 @@ class TestConcurrentCleanup:
|
||||
|
||||
# Run two cleanups concurrently
|
||||
# Use side_effect to return fresh session instances for each call
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', side_effect=lambda: AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.SessionLocal', side_effect=lambda: AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
results = await asyncio.gather(
|
||||
cleanup_expired_sessions(keep_days=30),
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
# tests/test_init_db.py
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.init_db import init_db
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
class TestInitDB:
|
||||
"""Tests for database initialization"""
|
||||
|
||||
def test_init_db_creates_superuser_when_not_exists(self, db_session, monkeypatch):
|
||||
"""Test that init_db creates superuser when it doesn't exist"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings to pick up environment variables
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
from app.core.config import settings
|
||||
|
||||
# Mock user_crud to return None (user doesn't exist)
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
|
||||
# Create a mock user to return from create
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.create.return_value = mock_user
|
||||
|
||||
# Call init_db
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify user was created
|
||||
assert user is not None
|
||||
assert user.email == "admin@test.com"
|
||||
assert user.is_superuser is True
|
||||
mock_crud.create.assert_called_once()
|
||||
|
||||
def test_init_db_returns_existing_superuser(self, db_session, monkeypatch):
|
||||
"""Test that init_db returns existing superuser without creating new one"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "existing@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud to return existing user
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
existing_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="existing@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Existing",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.get_by_email.return_value = existing_user
|
||||
|
||||
# Call init_db
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify existing user was returned
|
||||
assert user is not None
|
||||
assert user.email == "existing@test.com"
|
||||
# create should NOT be called
|
||||
mock_crud.create.assert_not_called()
|
||||
|
||||
def test_init_db_uses_defaults_when_env_not_set(self, db_session):
|
||||
"""Test that init_db uses default credentials when env vars not set"""
|
||||
# Mock settings to return None for superuser credentials
|
||||
with patch('app.init_db.settings') as mock_settings:
|
||||
mock_settings.FIRST_SUPERUSER_EMAIL = None
|
||||
mock_settings.FIRST_SUPERUSER_PASSWORD = None
|
||||
|
||||
# Mock user_crud
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.create.return_value = mock_user
|
||||
|
||||
# Call init_db
|
||||
with patch('app.init_db.logger') as mock_logger:
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify default email was used
|
||||
mock_crud.get_by_email.assert_called_with(db_session, email="admin@example.com")
|
||||
# Verify warning was logged since credentials not set
|
||||
assert mock_logger.warning.called
|
||||
|
||||
def test_init_db_handles_creation_error(self, db_session, monkeypatch):
|
||||
"""Test that init_db handles errors during user creation"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud to raise an exception
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
mock_crud.create.side_effect = Exception("Database error")
|
||||
|
||||
# Call init_db and expect exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
init_db(db_session)
|
||||
|
||||
assert "Database error" in str(exc_info.value)
|
||||
|
||||
def test_init_db_logs_superuser_creation(self, db_session, monkeypatch):
|
||||
"""Test that init_db logs appropriate messages"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.create.return_value = mock_user
|
||||
|
||||
# Call init_db with logger mock
|
||||
with patch('app.init_db.logger') as mock_logger:
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify info log was called
|
||||
assert mock_logger.info.called
|
||||
info_call_args = str(mock_logger.info.call_args)
|
||||
assert "Created first superuser" in info_call_args
|
||||
|
||||
def test_init_db_logs_existing_user(self, db_session, monkeypatch):
|
||||
"""Test that init_db logs when user already exists"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "existing@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud to return existing user
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
existing_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="existing@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Existing",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.get_by_email.return_value = existing_user
|
||||
|
||||
# Call init_db with logger mock
|
||||
with patch('app.init_db.logger') as mock_logger:
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify info log was called
|
||||
assert mock_logger.info.called
|
||||
info_call_args = str(mock_logger.info.call_args)
|
||||
assert "already exists" in info_call_args.lower()
|
||||
Reference in New Issue
Block a user