Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff

- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest).
- Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting.
- Updated `requirements.txt` to include Ruff and remove replaced tools.
- Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
2025-11-10 11:55:15 +01:00
parent a5c671c133
commit c589b565f0
86 changed files with 4572 additions and 3956 deletions

View File

@@ -2,14 +2,16 @@
"""
Comprehensive tests for CRUDBase class covering all error paths and edge cases.
"""
from datetime import UTC
from unittest.mock import patch
from uuid import uuid4
import pytest
from uuid import uuid4, UUID
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.orm import joinedload
from unittest.mock import AsyncMock, patch, MagicMock
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
@@ -19,7 +21,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_invalid_uuid_string(self, async_test_db):
"""Test get with invalid UUID string returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.get(session, id="invalid-uuid")
@@ -28,7 +30,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_invalid_uuid_type(self, async_test_db):
"""Test get with invalid UUID type returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.get(session, id=12345) # int instead of UUID
@@ -37,7 +39,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_uuid_object(self, async_test_db, async_test_user):
"""Test get with UUID object instead of string."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Pass UUID object directly
@@ -48,26 +50,24 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_options(self, async_test_db, async_test_user):
"""Test get with eager loading options (tests lines 76-78)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Test that options parameter is accepted and doesn't error
# We pass an empty list which still tests the code path
result = await user_crud.get(
session,
id=str(async_test_user.id),
options=[]
session, id=str(async_test_user.id), options=[]
)
assert result is not None
@pytest.mark.asyncio
async def test_get_database_error(self, async_test_db):
"""Test get handles database errors properly."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock execute to raise an exception
with patch.object(session, 'execute', side_effect=Exception("DB error")):
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.get(session, id=str(uuid4()))
@@ -78,7 +78,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_negative_skip(self, async_test_db):
"""Test get_multi with negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -87,7 +87,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_negative_limit(self, async_test_db):
"""Test get_multi with negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -96,7 +96,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_limit_too_large(self, async_test_db):
"""Test get_multi with limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
@@ -105,25 +105,20 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_with_options(self, async_test_db, async_test_user):
"""Test get_multi with eager loading options (tests lines 118-120)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Test that options parameter is accepted
results = await user_crud.get_multi(
session,
skip=0,
limit=10,
options=[]
)
results = await user_crud.get_multi(session, skip=0, limit=10, options=[])
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_multi_database_error(self, async_test_db):
"""Test get_multi handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("DB error")):
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.get_multi(session)
@@ -134,7 +129,7 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_duplicate_unique_field(self, async_test_db, async_test_user):
"""Test create with duplicate unique field raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Try to create user with duplicate email
@@ -142,7 +137,7 @@ class TestCRUDBaseCreate:
email=async_test_user.email, # Duplicate!
password="TestPassword123!",
first_name="Test",
last_name="Duplicate"
last_name="Duplicate",
)
with pytest.raises(ValueError, match="already exists"):
@@ -151,22 +146,23 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_integrity_error_non_duplicate(self, async_test_db):
"""Test create with non-duplicate IntegrityError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock commit to raise IntegrityError without "unique" in message
original_commit = session.commit
async def mock_commit():
error = IntegrityError("statement", {}, Exception("foreign key violation"))
error = IntegrityError(
"statement", {}, Exception("foreign key violation")
)
raise error
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, "commit", side_effect=mock_commit):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
with pytest.raises(ValueError, match="Database integrity error"):
@@ -175,15 +171,21 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_operational_error(self, async_test_db):
"""Test create with OperationalError (user CRUD catches as generic Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection lost"))):
with patch.object(
session,
"commit",
side_effect=OperationalError(
"statement", {}, Exception("connection lost")
),
):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -193,15 +195,19 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_data_error(self, async_test_db):
"""Test create with DataError (user CRUD catches as generic Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=DataError("statement", {}, Exception("invalid data"))):
with patch.object(
session,
"commit",
side_effect=DataError("statement", {}, Exception("invalid data")),
):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -211,15 +217,17 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db):
"""Test create with unexpected exception."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected error")):
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected error")
):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
with pytest.raises(RuntimeError, match="Unexpected error"):
@@ -232,16 +240,17 @@ class TestCRUDBaseUpdate:
@pytest.mark.asyncio
async def test_update_duplicate_unique_field(self, async_test_db, async_test_user):
"""Test update with duplicate unique field raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create another user
async with SessionLocal() as session:
from app.crud.user import user as user_crud
user2_data = UserCreate(
email="user2@example.com",
password="TestPassword123!",
first_name="User",
last_name="Two"
last_name="Two",
)
user2 = await user_crud.create(session, obj_in=user2_data)
await session.commit()
@@ -250,63 +259,89 @@ class TestCRUDBaseUpdate:
async with SessionLocal() as session:
user2_obj = await user_crud.get(session, id=str(user2.id))
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("UNIQUE constraint failed"))):
with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("UNIQUE constraint failed")
),
):
update_data = UserUpdate(email=async_test_user.email)
with pytest.raises(ValueError, match="already exists"):
await user_crud.update(session, db_obj=user2_obj, obj_in=update_data)
await user_crud.update(
session, db_obj=user2_obj, obj_in=update_data
)
@pytest.mark.asyncio
async def test_update_with_dict(self, async_test_db, async_test_user):
"""Test update with dict instead of schema."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
# Update with dict (tests lines 164-165)
updated = await user_crud.update(
session,
db_obj=user,
obj_in={"first_name": "UpdatedName"}
session, db_obj=user, obj_in={"first_name": "UpdatedName"}
)
assert updated.first_name == "UpdatedName"
@pytest.mark.asyncio
async def test_update_integrity_error(self, async_test_db, async_test_user):
"""Test update with IntegrityError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("constraint failed"))):
with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("constraint failed")
),
):
with pytest.raises(ValueError, match="Database integrity error"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user):
"""Test update with OperationalError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection error"))):
with patch.object(
session,
"commit",
side_effect=OperationalError(
"statement", {}, Exception("connection error")
),
):
with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@pytest.mark.asyncio
async def test_update_unexpected_error(self, async_test_db, async_test_user):
"""Test update with unexpected error."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
class TestCRUDBaseRemove:
@@ -315,7 +350,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_invalid_uuid(self, async_test_db):
"""Test remove with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.remove(session, id="invalid-uuid")
@@ -324,7 +359,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_with_uuid_object(self, async_test_db, async_test_user):
"""Test remove with UUID object."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a user to delete
async with SessionLocal() as session:
@@ -332,7 +367,7 @@ class TestCRUDBaseRemove:
email="todelete@example.com",
password="TestPassword123!",
first_name="To",
last_name="Delete"
last_name="Delete",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -347,7 +382,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_nonexistent(self, async_test_db):
"""Test remove of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.remove(session, id=str(uuid4()))
@@ -356,21 +391,31 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_integrity_error(self, async_test_db, async_test_user):
"""Test remove with IntegrityError (foreign key constraint)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock delete to raise IntegrityError
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("FOREIGN KEY constraint"))):
with pytest.raises(ValueError, match="Cannot delete.*referenced by other records"):
with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("FOREIGN KEY constraint")
),
):
with pytest.raises(
ValueError, match="Cannot delete.*referenced by other records"
):
await user_crud.remove(session, id=str(async_test_user.id))
@pytest.mark.asyncio
async def test_remove_unexpected_error(self, async_test_db, async_test_user):
"""Test remove with unexpected error."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError):
await user_crud.remove(session, id=str(async_test_user.id))
@@ -381,10 +426,12 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
"""Test get_multi_with_total basic functionality."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total(session, skip=0, limit=10)
items, total = await user_crud.get_multi_with_total(
session, skip=0, limit=10
)
assert isinstance(items, list)
assert isinstance(total, int)
assert total >= 1 # At least the test user
@@ -392,7 +439,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_skip(self, async_test_db):
"""Test get_multi_with_total with negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -401,7 +448,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db):
"""Test get_multi_with_total with negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -410,28 +457,34 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
"""Test get_multi_with_total with limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, limit=1001)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_filters(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with filters."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
filters = {"email": async_test_user.email}
items, total = await user_crud.get_multi_with_total(session, filters=filters)
items, total = await user_crud.get_multi_with_total(
session, filters=filters
)
assert total == 1
assert len(items) == 1
assert items[0].email == async_test_user.email
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_sorting_asc(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with ascending sort."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create additional users
async with SessionLocal() as session:
@@ -439,13 +492,13 @@ class TestCRUDBaseGetMultiWithTotal:
email="aaa@example.com",
password="TestPassword123!",
first_name="AAA",
last_name="User"
last_name="User",
)
user_data2 = UserCreate(
email="zzz@example.com",
password="TestPassword123!",
first_name="ZZZ",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
@@ -460,9 +513,11 @@ class TestCRUDBaseGetMultiWithTotal:
assert items[0].email == "aaa@example.com"
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_sorting_desc(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with descending sort."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create additional users
async with SessionLocal() as session:
@@ -470,20 +525,20 @@ class TestCRUDBaseGetMultiWithTotal:
email="bbb@example.com",
password="TestPassword123!",
first_name="BBB",
last_name="User"
last_name="User",
)
user_data2 = UserCreate(
email="ccc@example.com",
password="TestPassword123!",
first_name="CCC",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
await session.commit()
async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total(
items, _total = await user_crud.get_multi_with_total(
session, sort_by="email", sort_order="desc", limit=1
)
assert len(items) == 1
@@ -492,7 +547,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_with_pagination(self, async_test_db):
"""Test get_multi_with_total pagination works correctly."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create minimal users for pagination test (3 instead of 5)
async with SessionLocal() as session:
@@ -501,19 +556,23 @@ class TestCRUDBaseGetMultiWithTotal:
email=f"user{i}@example.com",
password="TestPassword123!",
first_name=f"User{i}",
last_name="Test"
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
await session.commit()
async with SessionLocal() as session:
# Get first page
items1, total = await user_crud.get_multi_with_total(session, skip=0, limit=2)
items1, total = await user_crud.get_multi_with_total(
session, skip=0, limit=2
)
assert len(items1) == 2
assert total >= 3
# Get second page
items2, total2 = await user_crud.get_multi_with_total(session, skip=2, limit=2)
items2, total2 = await user_crud.get_multi_with_total(
session, skip=2, limit=2
)
assert len(items2) >= 1
assert total2 == total
@@ -529,7 +588,7 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio
async def test_count_basic(self, async_test_db, async_test_user):
"""Test count returns correct number."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
count = await user_crud.count(session)
@@ -539,7 +598,7 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio
async def test_count_multiple_users(self, async_test_db, async_test_user):
"""Test count with multiple users."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create additional users
async with SessionLocal() as session:
@@ -549,13 +608,13 @@ class TestCRUDBaseCount:
email="count1@example.com",
password="TestPassword123!",
first_name="Count",
last_name="One"
last_name="One",
)
user_data2 = UserCreate(
email="count2@example.com",
password="TestPassword123!",
first_name="Count",
last_name="Two"
last_name="Two",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
@@ -568,10 +627,10 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio
async def test_count_database_error(self, async_test_db):
"""Test count handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("DB error")):
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.count(session)
@@ -582,7 +641,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio
async def test_exists_true(self, async_test_db, async_test_user):
"""Test exists returns True for existing record."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(async_test_user.id))
@@ -591,7 +650,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio
async def test_exists_false(self, async_test_db):
"""Test exists returns False for non-existent record."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(uuid4()))
@@ -600,7 +659,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio
async def test_exists_invalid_uuid(self, async_test_db):
"""Test exists returns False for invalid UUID."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id="invalid-uuid")
@@ -613,7 +672,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_success(self, async_test_db):
"""Test soft delete sets deleted_at timestamp."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a user to soft delete
async with SessionLocal() as session:
@@ -621,7 +680,7 @@ class TestCRUDBaseSoftDelete:
email="softdelete@example.com",
password="TestPassword123!",
first_name="Soft",
last_name="Delete"
last_name="Delete",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -636,7 +695,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_invalid_uuid(self, async_test_db):
"""Test soft delete with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id="invalid-uuid")
@@ -645,7 +704,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_nonexistent(self, async_test_db):
"""Test soft delete of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id=str(uuid4()))
@@ -654,7 +713,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_with_uuid_object(self, async_test_db):
"""Test soft delete with UUID object."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a user to soft delete
async with SessionLocal() as session:
@@ -662,7 +721,7 @@ class TestCRUDBaseSoftDelete:
email="softdelete2@example.com",
password="TestPassword123!",
first_name="Soft",
last_name="Delete2"
last_name="Delete2",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -681,7 +740,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_success(self, async_test_db):
"""Test restore clears deleted_at timestamp."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create and soft delete a user
async with SessionLocal() as session:
@@ -689,7 +748,7 @@ class TestCRUDBaseRestore:
email="restore@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test"
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -707,7 +766,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_invalid_uuid(self, async_test_db):
"""Test restore with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.restore(session, id="invalid-uuid")
@@ -716,7 +775,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_nonexistent(self, async_test_db):
"""Test restore of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.restore(session, id=str(uuid4()))
@@ -725,7 +784,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_not_deleted(self, async_test_db, async_test_user):
"""Test restore of non-deleted record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Try to restore a user that's not deleted
@@ -735,7 +794,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_with_uuid_object(self, async_test_db):
"""Test restore with UUID object."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create and soft delete a user
async with SessionLocal() as session:
@@ -743,7 +802,7 @@ class TestCRUDBaseRestore:
email="restore2@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test2"
last_name="Test2",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -765,7 +824,7 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_skip(self, async_test_db):
"""Test that negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -774,7 +833,7 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db):
"""Test that negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -783,23 +842,22 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
"""Test that limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_filters(
self, async_test_db, async_test_user
):
"""Test pagination with filters (covers lines 270-273)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
filters={"is_active": True}
session, skip=0, limit=10, filters={"is_active": True}
)
assert isinstance(users, list)
assert total >= 0
@@ -807,30 +865,22 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db):
"""Test pagination with descending sort (covers lines 283-284)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="created_at",
sort_order="desc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
)
assert isinstance(users, list)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db):
"""Test pagination with ascending sort (covers lines 285-286)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="created_at",
sort_order="asc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
)
assert isinstance(users, list)
@@ -842,13 +892,15 @@ class TestCRUDBaseModelsWithoutSoftDelete:
"""
@pytest.mark.asyncio
async def test_soft_delete_model_without_deleted_at(self, async_test_db, async_test_user):
async def test_soft_delete_model_without_deleted_at(
self, async_test_db, async_test_user
):
"""Test soft_delete on Organization model (no deleted_at) raises ValueError (covers lines 342-343)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.crud.organization import organization as org_crud
from app.models.organization import Organization
async with SessionLocal() as session:
org = Organization(name="Test Org", slug="test-org")
@@ -864,11 +916,11 @@ class TestCRUDBaseModelsWithoutSoftDelete:
@pytest.mark.asyncio
async def test_restore_model_without_deleted_at(self, async_test_db):
"""Test restore on Organization model (no deleted_at) raises ValueError (covers lines 383-384)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.crud.organization import organization as org_crud
from app.models.organization import Organization
async with SessionLocal() as session:
org = Organization(name="Restore Test", slug="restore-test")
@@ -889,14 +941,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
"""
@pytest.mark.asyncio
async def test_get_with_real_eager_loading_options(self, async_test_db, async_test_user):
async def test_get_with_real_eager_loading_options(
self, async_test_db, async_test_user
):
"""Test get() with actual eager loading options (covers lines 77-78)."""
from datetime import datetime, timedelta, timezone
test_engine, SessionLocal = async_test_db
from datetime import datetime, timedelta
_test_engine, SessionLocal = async_test_db
# Create a session for the user
from app.models.user_session import UserSession
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
async with SessionLocal() as session:
user_session = UserSession(
@@ -905,8 +960,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
device_id="test-device",
ip_address="192.168.1.1",
user_agent="Test Agent",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60),
)
session.add(user_session)
await session.commit()
@@ -917,7 +972,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
result = await session_crud.get(
session,
id=str(session_id),
options=[joinedload(UserSession.user)] # Real option, not empty list
options=[joinedload(UserSession.user)], # Real option, not empty list
)
assert result is not None
assert result.id == session_id
@@ -925,14 +980,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
assert result.user.email == async_test_user.email
@pytest.mark.asyncio
async def test_get_multi_with_real_eager_loading_options(self, async_test_db, async_test_user):
async def test_get_multi_with_real_eager_loading_options(
self, async_test_db, async_test_user
):
"""Test get_multi() with actual eager loading options (covers lines 119-120)."""
from datetime import datetime, timedelta, timezone
test_engine, SessionLocal = async_test_db
from datetime import datetime, timedelta
_test_engine, SessionLocal = async_test_db
# Create multiple sessions for the user
from app.models.user_session import UserSession
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
async with SessionLocal() as session:
for i in range(3):
@@ -942,8 +1000,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
device_id=f"device-{i}",
ip_address=f"192.168.1.{i}",
user_agent=f"Agent {i}",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60),
)
session.add(user_session)
await session.commit()
@@ -954,7 +1012,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
session,
skip=0,
limit=10,
options=[joinedload(UserSession.user)] # Real option, not empty list
options=[joinedload(UserSession.user)], # Real option, not empty list
)
assert len(results) >= 3
# Verify we can access user without additional queries

View File

@@ -3,13 +3,15 @@
Comprehensive tests for base CRUD database failure scenarios.
Tests exception handling, rollbacks, and error messages.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.exc import DataError, OperationalError
from app.crud.user import user as user_crud
from app.schemas.users import UserCreate, UserUpdate
from app.schemas.users import UserCreate
class TestBaseCRUDCreateFailures:
@@ -18,19 +20,24 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio
async def test_create_operational_error_triggers_rollback(self, async_test_db):
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Connection lost", {}, Exception("DB connection failed"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
async def mock_commit():
raise OperationalError(
"Connection lost", {}, Exception("DB connection failed")
)
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="operror@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -43,19 +50,22 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio
async def test_create_data_error_triggers_rollback(self, async_test_db):
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise DataError("Invalid data type", {}, Exception("Data overflow"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="dataerror@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -67,19 +77,22 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio
async def test_create_unexpected_exception_triggers_rollback(self, async_test_db):
"""Test that unexpected exceptions trigger rollback and re-raise."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Unexpected database error")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="unexpected@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
with pytest.raises(RuntimeError, match="Unexpected database error"):
@@ -94,7 +107,7 @@ class TestBaseCRUDUpdateFailures:
@pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user):
"""Test update with OperationalError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -102,17 +115,21 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit():
raise OperationalError("Connection timeout", {}, Exception("Timeout"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_data_error(self, async_test_db, async_test_user):
"""Test update with DataError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -120,17 +137,21 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit():
raise DataError("Invalid data", {}, Exception("Data type mismatch"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_unexpected_error(self, async_test_db, async_test_user):
"""Test update with unexpected error."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -138,10 +159,14 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit():
raise KeyError("Unexpected error")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(KeyError):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@@ -150,16 +175,21 @@ class TestBaseCRUDRemoveFailures:
"""Test base CRUD remove method exception handling."""
@pytest.mark.asyncio
async def test_remove_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
async def test_remove_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test that unexpected errors in remove trigger rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Database write failed")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Database write failed"):
await user_crud.remove(session, id=str(async_test_user.id))
@@ -172,16 +202,15 @@ class TestBaseCRUDGetMultiWithTotalFailures:
@pytest.mark.asyncio
async def test_get_multi_with_total_database_error(self, async_test_db):
"""Test get_multi_with_total handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock execute to raise an error
original_execute = session.execute
async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("Database error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi_with_total(session, skip=0, limit=10)
@@ -192,13 +221,14 @@ class TestBaseCRUDCountFailures:
@pytest.mark.asyncio
async def test_count_database_error_propagates(self, async_test_db):
"""Test count propagates database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Count failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.count(session)
@@ -207,16 +237,21 @@ class TestBaseCRUDSoftDeleteFailures:
"""Test soft_delete method exception handling."""
@pytest.mark.asyncio
async def test_soft_delete_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
async def test_soft_delete_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test soft_delete handles unexpected errors with rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Soft delete failed")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Soft delete failed"):
await user_crud.soft_delete(session, id=str(async_test_user.id))
@@ -229,7 +264,7 @@ class TestBaseCRUDRestoreFailures:
@pytest.mark.asyncio
async def test_restore_unexpected_error_triggers_rollback(self, async_test_db):
"""Test restore handles unexpected errors with rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# First create and soft delete a user
async with SessionLocal() as session:
@@ -237,7 +272,7 @@ class TestBaseCRUDRestoreFailures:
email="restore_test@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test"
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -248,11 +283,14 @@ class TestBaseCRUDRestoreFailures:
# Now test restore failure
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Restore failed")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Restore failed"):
await user_crud.restore(session, id=str(user_id))
@@ -265,13 +303,14 @@ class TestBaseCRUDGetFailures:
@pytest.mark.asyncio
async def test_get_database_error_propagates(self, async_test_db):
"""Test get propagates database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Get failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get(session, id=str(uuid4()))
@@ -282,12 +321,13 @@ class TestBaseCRUDGetMultiFailures:
@pytest.mark.asyncio
async def test_get_multi_database_error_propagates(self, async_test_db):
"""Test get_multi propagates database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi(session, skip=0, limit=10)

File diff suppressed because it is too large Load Diff

View File

@@ -2,10 +2,12 @@
"""
Comprehensive tests for async session CRUD operations.
"""
import pytest
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate
@@ -17,7 +19,7 @@ class TestGetByJti:
@pytest.mark.asyncio
async def test_get_by_jti_success(self, async_test_db, async_test_user):
"""Test getting session by JTI."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -27,8 +29,8 @@ class TestGetByJti:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -41,7 +43,7 @@ class TestGetByJti:
@pytest.mark.asyncio
async def test_get_by_jti_not_found(self, async_test_db):
"""Test getting non-existent JTI returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="nonexistent")
@@ -54,7 +56,7 @@ class TestGetActiveByJti:
@pytest.mark.asyncio
async def test_get_active_by_jti_success(self, async_test_db, async_test_user):
"""Test getting active session by JTI."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -64,8 +66,8 @@ class TestGetActiveByJti:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -78,7 +80,7 @@ class TestGetActiveByJti:
@pytest.mark.asyncio
async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user):
"""Test getting inactive session by JTI returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -88,8 +90,8 @@ class TestGetActiveByJti:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -105,7 +107,7 @@ class TestGetUserSessions:
@pytest.mark.asyncio
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
"""Test getting only active user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
active = UserSession(
@@ -115,8 +117,8 @@ class TestGetUserSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
inactive = UserSession(
user_id=async_test_user.id,
@@ -125,17 +127,15 @@ class TestGetUserSessions:
ip_address="192.168.1.2",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add_all([active, inactive])
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id),
active_only=True
session, user_id=str(async_test_user.id), active_only=True
)
assert len(results) == 1
assert results[0].is_active is True
@@ -143,7 +143,7 @@ class TestGetUserSessions:
@pytest.mark.asyncio
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
"""Test getting all user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
for i in range(3):
@@ -154,17 +154,15 @@ class TestGetUserSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=i % 2 == 0,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id),
active_only=False
session, user_id=str(async_test_user.id), active_only=False
)
assert len(results) == 3
@@ -175,7 +173,7 @@ class TestCreateSession:
@pytest.mark.asyncio
async def test_create_session_success(self, async_test_db, async_test_user):
"""Test successfully creating a session_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
session_data = SessionCreate(
@@ -185,10 +183,10 @@ class TestCreateSession:
device_id="device_123",
ip_address="192.168.1.100",
user_agent="Mozilla/5.0",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=7),
location_city="San Francisco",
location_country="USA"
location_country="USA",
)
result = await session_crud.create_session(session, obj_in=session_data)
@@ -204,7 +202,7 @@ class TestDeactivate:
@pytest.mark.asyncio
async def test_deactivate_success(self, async_test_db, async_test_user):
"""Test successfully deactivating a session_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -214,8 +212,8 @@ class TestDeactivate:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -229,7 +227,7 @@ class TestDeactivate:
@pytest.mark.asyncio
async def test_deactivate_not_found(self, async_test_db):
"""Test deactivating non-existent session returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(uuid4()))
@@ -240,9 +238,11 @@ class TestDeactivateAllUserSessions:
"""Tests for deactivate_all_user_sessions method."""
@pytest.mark.asyncio
async def test_deactivate_all_user_sessions_success(self, async_test_db, async_test_user):
async def test_deactivate_all_user_sessions_success(
self, async_test_db, async_test_user
):
"""Test deactivating all user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create minimal sessions for test (2 instead of 5)
@@ -254,16 +254,15 @@ class TestDeactivateAllUserSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.deactivate_all_user_sessions(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 2
@@ -274,7 +273,7 @@ class TestUpdateLastUsed:
@pytest.mark.asyncio
async def test_update_last_used_success(self, async_test_db, async_test_user):
"""Test updating last_used_at timestamp."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -284,8 +283,8 @@ class TestUpdateLastUsed:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
@@ -303,7 +302,7 @@ class TestGetUserSessionCount:
@pytest.mark.asyncio
async def test_get_user_session_count_success(self, async_test_db, async_test_user):
"""Test getting user session count."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
for i in range(3):
@@ -314,28 +313,26 @@ class TestGetUserSessionCount:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 3
@pytest.mark.asyncio
async def test_get_user_session_count_empty(self, async_test_db):
"""Test getting session count for user with no sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
session,
user_id=str(uuid4())
session, user_id=str(uuid4())
)
assert count == 0
@@ -346,7 +343,7 @@ class TestUpdateRefreshToken:
@pytest.mark.asyncio
async def test_update_refresh_token_success(self, async_test_db, async_test_user):
"""Test updating refresh token JTI and expiration."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -356,26 +353,34 @@ class TestUpdateRefreshToken:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
await session.refresh(user_session)
new_jti = "new_jti_123"
new_expires = datetime.now(timezone.utc) + timedelta(days=14)
new_expires = datetime.now(UTC) + timedelta(days=14)
result = await session_crud.update_refresh_token(
session,
session=user_session,
new_jti=new_jti,
new_expires_at=new_expires
new_expires_at=new_expires,
)
assert result.refresh_token_jti == new_jti
# Compare timestamps ignoring timezone info
assert abs((result.expires_at.replace(tzinfo=None) - new_expires.replace(tzinfo=None)).total_seconds()) < 1
assert (
abs(
(
result.expires_at.replace(tzinfo=None)
- new_expires.replace(tzinfo=None)
).total_seconds()
)
< 1
)
class TestCleanupExpired:
@@ -384,7 +389,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio
async def test_cleanup_expired_success(self, async_test_db, async_test_user):
"""Test cleaning up old expired inactive sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired inactive session
async with AsyncTestingSessionLocal() as session:
@@ -395,9 +400,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
created_at=datetime.now(timezone.utc) - timedelta(days=35)
expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(UTC) - timedelta(days=35),
)
session.add(old_session)
await session.commit()
@@ -410,7 +415,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio
async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user):
"""Test that cleanup keeps recent expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create recent expired inactive session (less than keep_days old)
async with AsyncTestingSessionLocal() as session:
@@ -421,9 +426,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2),
created_at=datetime.now(timezone.utc) - timedelta(days=1)
expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2),
created_at=datetime.now(UTC) - timedelta(days=1),
)
session.add(recent_session)
await session.commit()
@@ -436,7 +441,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio
async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user):
"""Test that cleanup does not delete active sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired but ACTIVE session
async with AsyncTestingSessionLocal() as session:
@@ -447,9 +452,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True, # Active
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
created_at=datetime.now(timezone.utc) - timedelta(days=35)
expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(UTC) - timedelta(days=35),
)
session.add(active_session)
await session.commit()
@@ -464,9 +469,11 @@ class TestCleanupExpiredForUser:
"""Tests for cleanup_expired_for_user method."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_success(self, async_test_db, async_test_user):
async def test_cleanup_expired_for_user_success(
self, async_test_db, async_test_user
):
"""Test cleaning up expired sessions for specific user."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired inactive session for user
async with AsyncTestingSessionLocal() as session:
@@ -477,8 +484,8 @@ class TestCleanupExpiredForUser:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
session.add(expired_session)
await session.commit()
@@ -486,27 +493,27 @@ class TestCleanupExpiredForUser:
# Cleanup for user
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 1
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db):
"""Test cleanup with invalid user UUID."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError, match="Invalid user ID format"):
await session_crud.cleanup_expired_for_user(
session,
user_id="not-a-valid-uuid"
session, user_id="not-a-valid-uuid"
)
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_keeps_active(self, async_test_db, async_test_user):
async def test_cleanup_expired_for_user_keeps_active(
self, async_test_db, async_test_user
):
"""Test that cleanup for user keeps active sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired but active session
async with AsyncTestingSessionLocal() as session:
@@ -517,8 +524,8 @@ class TestCleanupExpiredForUser:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True, # Active
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
session.add(active_session)
await session.commit()
@@ -526,8 +533,7 @@ class TestCleanupExpiredForUser:
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 0 # Should not delete active sessions
@@ -536,9 +542,11 @@ class TestGetUserSessionsWithUser:
"""Tests for get_user_sessions with eager loading."""
@pytest.mark.asyncio
async def test_get_user_sessions_with_user_relationship(self, async_test_db, async_test_user):
async def test_get_user_sessions_with_user_relationship(
self, async_test_db, async_test_user
):
"""Test getting sessions with user relationship loaded."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -548,8 +556,8 @@ class TestGetUserSessionsWithUser:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -557,8 +565,6 @@ class TestGetUserSessionsWithUser:
# Get with user relationship
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id),
with_user=True
session, user_id=str(async_test_user.id), with_user=True
)
assert len(results) >= 1

View File

@@ -2,12 +2,14 @@
"""
Comprehensive tests for session CRUD database failure scenarios.
"""
import pytest
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch
from sqlalchemy.exc import OperationalError, IntegrityError
from datetime import datetime, timedelta, timezone
from uuid import uuid4
import pytest
from sqlalchemy.exc import OperationalError
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate
@@ -19,13 +21,14 @@ class TestSessionCRUDGetByJtiFailures:
@pytest.mark.asyncio
async def test_get_by_jti_database_error(self, async_test_db):
"""Test get_by_jti handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("DB connection lost", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_by_jti(session, jti="test_jti")
@@ -36,13 +39,14 @@ class TestSessionCRUDGetActiveByJtiFailures:
@pytest.mark.asyncio
async def test_get_active_by_jti_database_error(self, async_test_db):
"""Test get_active_by_jti handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Query timeout", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_active_by_jti(session, jti="test_jti")
@@ -51,19 +55,21 @@ class TestSessionCRUDGetUserSessionsFailures:
"""Test get_user_sessions exception handling."""
@pytest.mark.asyncio
async def test_get_user_sessions_database_error(self, async_test_db, async_test_user):
async def test_get_user_sessions_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_sessions handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Database error", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
@@ -71,24 +77,29 @@ class TestSessionCRUDCreateSessionFailures:
"""Test create_session exception handling."""
@pytest.mark.asyncio
async def test_create_session_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_create_session_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles commit failures with rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Commit failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
with pytest.raises(ValueError, match="Failed to create session"):
@@ -97,24 +108,29 @@ class TestSessionCRUDCreateSessionFailures:
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_create_session_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
async def test_create_session_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles unexpected errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Unexpected error")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
with pytest.raises(ValueError, match="Failed to create session"):
@@ -127,9 +143,11 @@ class TestSessionCRUDDeactivateFailures:
"""Test deactivate exception handling."""
@pytest.mark.asyncio
async def test_deactivate_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_deactivate_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session first
async with SessionLocal() as session:
@@ -140,8 +158,8 @@ class TestSessionCRUDDeactivateFailures:
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -150,13 +168,18 @@ class TestSessionCRUDDeactivateFailures:
# Test deactivate failure
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Deactivate failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate(session, session_id=str(session_id))
await session_crud.deactivate(
session, session_id=str(session_id)
)
mock_rollback.assert_called_once()
@@ -165,20 +188,24 @@ class TestSessionCRUDDeactivateAllFailures:
"""Test deactivate_all_user_sessions exception handling."""
@pytest.mark.asyncio
async def test_deactivate_all_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_deactivate_all_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate_all handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Bulk deactivate failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate_all_user_sessions(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
@@ -188,9 +215,11 @@ class TestSessionCRUDUpdateLastUsedFailures:
"""Test update_last_used exception handling."""
@pytest.mark.asyncio
async def test_update_last_used_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_update_last_used_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_last_used handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session
async with SessionLocal() as session:
@@ -201,8 +230,8 @@ class TestSessionCRUDUpdateLastUsedFailures:
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
@@ -211,15 +240,19 @@ class TestSessionCRUDUpdateLastUsedFailures:
# Test update failure
async with SessionLocal() as session:
from sqlalchemy import select
from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one()
async def mock_commit():
raise OperationalError("Update failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_last_used(session, session=sess)
@@ -230,9 +263,11 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
"""Test update_refresh_token exception handling."""
@pytest.mark.asyncio
async def test_update_refresh_token_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_update_refresh_token_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_refresh_token handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session
async with SessionLocal() as session:
@@ -243,8 +278,8 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -253,21 +288,25 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
# Test update failure
async with SessionLocal() as session:
from sqlalchemy import select
from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one()
async def mock_commit():
raise OperationalError("Token update failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_refresh_token(
session,
session=sess,
new_jti=str(uuid4()),
new_expires_at=datetime.now(timezone.utc) + timedelta(days=14)
new_expires_at=datetime.now(UTC) + timedelta(days=14),
)
mock_rollback.assert_called_once()
@@ -277,16 +316,21 @@ class TestSessionCRUDCleanupExpiredFailures:
"""Test cleanup_expired exception handling."""
@pytest.mark.asyncio
async def test_cleanup_expired_commit_failure_triggers_rollback(self, async_test_db):
async def test_cleanup_expired_commit_failure_triggers_rollback(
self, async_test_db
):
"""Test cleanup_expired handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Cleanup failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired(session, keep_days=30)
@@ -297,20 +341,24 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
"""Test cleanup_expired_for_user exception handling."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test cleanup_expired_for_user handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("User cleanup failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired_for_user(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
@@ -320,17 +368,19 @@ class TestSessionCRUDGetUserSessionCountFailures:
"""Test get_user_session_count exception handling."""
@pytest.mark.asyncio
async def test_get_user_session_count_database_error(self, async_test_db, async_test_user):
async def test_get_user_session_count_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_session_count handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Count query failed", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_session_count(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)

View File

@@ -2,12 +2,10 @@
"""
Comprehensive tests for async user CRUD operations.
"""
import pytest
from datetime import datetime, timezone
from uuid import uuid4
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
@@ -17,7 +15,7 @@ class TestGetByEmail:
@pytest.mark.asyncio
async def test_get_by_email_success(self, async_test_db, async_test_user):
"""Test getting user by email."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email=async_test_user.email)
@@ -28,10 +26,12 @@ class TestGetByEmail:
@pytest.mark.asyncio
async def test_get_by_email_not_found(self, async_test_db):
"""Test getting non-existent email returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email="nonexistent@example.com")
result = await user_crud.get_by_email(
session, email="nonexistent@example.com"
)
assert result is None
@@ -41,7 +41,7 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_user_success(self, async_test_db):
"""Test successfully creating a user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
@@ -49,7 +49,7 @@ class TestCreate:
password="SecurePass123!",
first_name="New",
last_name="User",
phone_number="+1234567890"
phone_number="+1234567890",
)
result = await user_crud.create(session, obj_in=user_data)
@@ -65,7 +65,7 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_superuser_success(self, async_test_db):
"""Test creating a superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
@@ -73,7 +73,7 @@ class TestCreate:
password="SuperPass123!",
first_name="Super",
last_name="User",
is_superuser=True
is_superuser=True,
)
result = await user_crud.create(session, obj_in=user_data)
@@ -83,14 +83,14 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
"""Test creating user with duplicate email raises ValueError."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email=async_test_user.email, # Duplicate email
password="AnotherPass123!",
first_name="Duplicate",
last_name="User"
last_name="User",
)
with pytest.raises(ValueError) as exc_info:
@@ -105,16 +105,14 @@ class TestUpdate:
@pytest.mark.asyncio
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
"""Test updating basic user fields."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get fresh copy of user
user = await user_crud.get(session, id=str(async_test_user.id))
update_data = UserUpdate(
first_name="Updated",
last_name="Name",
phone_number="+9876543210"
first_name="Updated", last_name="Name", phone_number="+9876543210"
)
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
@@ -125,7 +123,7 @@ class TestUpdate:
@pytest.mark.asyncio
async def test_update_user_password(self, async_test_db):
"""Test updating user password."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a fresh user for this test
async with AsyncTestingSessionLocal() as session:
@@ -133,7 +131,7 @@ class TestUpdate:
email="passwordtest@example.com",
password="OldPassword123!",
first_name="Pass",
last_name="Test"
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -149,12 +147,14 @@ class TestUpdate:
await session.refresh(result)
assert result.password_hash != old_password_hash
assert result.password_hash is not None
assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed
assert (
"NewDifferentPassword123!" not in result.password_hash
) # Should be hashed
@pytest.mark.asyncio
async def test_update_user_with_dict(self, async_test_db, async_test_user):
"""Test updating user with dictionary."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -171,13 +171,11 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
"""Test basic pagination."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10
session, skip=0, limit=10
)
assert total >= 1
assert len(users) >= 1
@@ -186,7 +184,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
"""Test sorting in ascending order."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -195,17 +193,13 @@ class TestGetMultiWithTotal:
email=f"sort{i}@example.com",
password="SecurePass123!",
first_name=f"User{i}",
last_name="Test"
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="email",
sort_order="asc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="asc"
)
# Check if sorted (at least the test users)
@@ -216,7 +210,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
"""Test sorting in descending order."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -225,17 +219,13 @@ class TestGetMultiWithTotal:
email=f"desc{i}@example.com",
password="SecurePass123!",
first_name=f"User{i}",
last_name="Test"
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="email",
sort_order="desc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="desc"
)
# Check if sorted descending (at least the test users)
@@ -246,7 +236,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_filtering(self, async_test_db):
"""Test filtering by field."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create active and inactive users
async with AsyncTestingSessionLocal() as session:
@@ -254,7 +244,7 @@ class TestGetMultiWithTotal:
email="active@example.com",
password="SecurePass123!",
first_name="Active",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=active_user)
@@ -262,23 +252,18 @@ class TestGetMultiWithTotal:
email="inactive@example.com",
password="SecurePass123!",
first_name="Inactive",
last_name="User"
last_name="User",
)
created_inactive = await user_crud.create(session, obj_in=inactive_user)
# Deactivate the user
await user_crud.update(
session,
db_obj=created_inactive,
obj_in={"is_active": False}
session, db_obj=created_inactive, obj_in={"is_active": False}
)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=100,
filters={"is_active": True}
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=100, filters={"is_active": True}
)
# All returned users should be active
@@ -287,7 +272,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_search(self, async_test_db):
"""Test search functionality."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user with unique name
async with AsyncTestingSessionLocal() as session:
@@ -295,16 +280,13 @@ class TestGetMultiWithTotal:
email="searchable@example.com",
password="SecurePass123!",
first_name="Searchable",
last_name="UserName"
last_name="UserName",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=100,
search="Searchable"
session, skip=0, limit=100, search="Searchable"
)
assert total >= 1
@@ -313,7 +295,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_pagination(self, async_test_db):
"""Test pagination with skip and limit."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -322,23 +304,19 @@ class TestGetMultiWithTotal:
email=f"page{i}@example.com",
password="SecurePass123!",
first_name=f"Page{i}",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
# Get first page
users_page1, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=2
session, skip=0, limit=2
)
# Get second page
users_page2, total2 = await user_crud.get_multi_with_total(
session,
skip=2,
limit=2
session, skip=2, limit=2
)
# Total should be same
@@ -349,7 +327,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
"""Test validation fails for negative skip."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
@@ -360,7 +338,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
"""Test validation fails for negative limit."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
@@ -371,7 +349,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
"""Test validation fails for limit > 1000."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
@@ -386,7 +364,7 @@ class TestBulkUpdateStatus:
@pytest.mark.asyncio
async def test_bulk_update_status_success(self, async_test_db):
"""Test bulk updating user status."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
@@ -396,7 +374,7 @@ class TestBulkUpdateStatus:
email=f"bulk{i}@example.com",
password="SecurePass123!",
first_name=f"Bulk{i}",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
@@ -404,9 +382,7 @@ class TestBulkUpdateStatus:
# Bulk deactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session,
user_ids=user_ids,
is_active=False
session, user_ids=user_ids, is_active=False
)
assert count == 3
@@ -419,20 +395,18 @@ class TestBulkUpdateStatus:
@pytest.mark.asyncio
async def test_bulk_update_status_empty_list(self, async_test_db):
"""Test bulk update with empty list returns 0."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session,
user_ids=[],
is_active=False
session, user_ids=[], is_active=False
)
assert count == 0
@pytest.mark.asyncio
async def test_bulk_update_status_reactivate(self, async_test_db):
"""Test bulk reactivating users."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create inactive user
async with AsyncTestingSessionLocal() as session:
@@ -440,7 +414,7 @@ class TestBulkUpdateStatus:
email="reactivate@example.com",
password="SecurePass123!",
first_name="Reactivate",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
# Deactivate
@@ -450,9 +424,7 @@ class TestBulkUpdateStatus:
# Reactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session,
user_ids=[user_id],
is_active=True
session, user_ids=[user_id], is_active=True
)
assert count == 1
@@ -468,7 +440,7 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio
async def test_bulk_soft_delete_success(self, async_test_db):
"""Test bulk soft deleting users."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
@@ -478,17 +450,14 @@ class TestBulkSoftDelete:
email=f"delete{i}@example.com",
password="SecurePass123!",
first_name=f"Delete{i}",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
# Bulk delete
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=user_ids
)
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
assert count == 3
# Verify all are soft deleted
@@ -501,7 +470,7 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
"""Test bulk soft delete with excluded user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
@@ -511,7 +480,7 @@ class TestBulkSoftDelete:
email=f"exclude{i}@example.com",
password="SecurePass123!",
first_name=f"Exclude{i}",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
@@ -520,9 +489,7 @@ class TestBulkSoftDelete:
exclude_id = user_ids[0]
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=user_ids,
exclude_user_id=exclude_id
session, user_ids=user_ids, exclude_user_id=exclude_id
)
assert count == 2 # Only 2 deleted
@@ -534,19 +501,16 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio
async def test_bulk_soft_delete_empty_list(self, async_test_db):
"""Test bulk delete with empty list returns 0."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=[]
)
count = await user_crud.bulk_soft_delete(session, user_ids=[])
assert count == 0
@pytest.mark.asyncio
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
"""Test bulk delete where all users are excluded."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user
async with AsyncTestingSessionLocal() as session:
@@ -554,7 +518,7 @@ class TestBulkSoftDelete:
email="onlyuser@example.com",
password="SecurePass123!",
first_name="Only",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -562,16 +526,14 @@ class TestBulkSoftDelete:
# Try to delete but exclude
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=[user_id],
exclude_user_id=user_id
session, user_ids=[user_id], exclude_user_id=user_id
)
assert count == 0
@pytest.mark.asyncio
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
"""Test bulk delete doesn't re-delete already deleted users."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create and delete user
async with AsyncTestingSessionLocal() as session:
@@ -579,7 +541,7 @@ class TestBulkSoftDelete:
email="predeleted@example.com",
password="SecurePass123!",
first_name="PreDeleted",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -589,10 +551,7 @@ class TestBulkSoftDelete:
# Try to delete again
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=[user_id]
)
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
assert count == 0 # Already deleted
@@ -602,7 +561,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_active_true(self, async_test_db, async_test_user):
"""Test is_active returns True for active user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -611,14 +570,14 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_active_false(self, async_test_db):
"""Test is_active returns False for inactive user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="inactive2@example.com",
password="SecurePass123!",
first_name="Inactive",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
@@ -628,7 +587,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
"""Test is_superuser returns True for superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_superuser.id))
@@ -637,7 +596,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_superuser_false(self, async_test_db, async_test_user):
"""Test is_superuser returns False for regular user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -654,42 +613,52 @@ class TestUserExceptionHandlers:
async def test_get_by_email_database_error(self, async_test_db):
"""Test get_by_email handles database errors (covers lines 30-32)."""
from unittest.mock import patch
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("Database query failed")):
with patch.object(
session, "execute", side_effect=Exception("Database query failed")
):
with pytest.raises(Exception, match="Database query failed"):
await user_crud.get_by_email(session, email="test@example.com")
@pytest.mark.asyncio
async def test_bulk_update_status_database_error(self, async_test_db, async_test_user):
async def test_bulk_update_status_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_update_status handles database errors (covers lines 205-208)."""
from unittest.mock import patch, AsyncMock
test_engine, AsyncTestingSessionLocal = async_test_db
from unittest.mock import AsyncMock, patch
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock execute to fail
with patch.object(session, 'execute', side_effect=Exception("Bulk update failed")):
with patch.object(session, 'rollback', new_callable=AsyncMock):
with patch.object(
session, "execute", side_effect=Exception("Bulk update failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk update failed"):
await user_crud.bulk_update_status(
session,
user_ids=[async_test_user.id],
is_active=False
session, user_ids=[async_test_user.id], is_active=False
)
@pytest.mark.asyncio
async def test_bulk_soft_delete_database_error(self, async_test_db, async_test_user):
async def test_bulk_soft_delete_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_soft_delete handles database errors (covers lines 257-260)."""
from unittest.mock import patch, AsyncMock
test_engine, AsyncTestingSessionLocal = async_test_db
from unittest.mock import AsyncMock, patch
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock execute to fail
with patch.object(session, 'execute', side_effect=Exception("Bulk delete failed")):
with patch.object(session, 'rollback', new_callable=AsyncMock):
with patch.object(
session, "execute", side_effect=Exception("Bulk delete failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk delete failed"):
await user_crud.bulk_soft_delete(
session,
user_ids=[async_test_user.id]
session, user_ids=[async_test_user.id]
)