refactor(tests): replace crud references with repo across repository test files

- Updated import statements and test logic to align with `repositories` naming changes.
- Adjusted documentation and test names for consistency with the updated naming convention.
- Improved test descriptions to reflect the repository-based structure.
This commit is contained in:
2026-03-01 19:22:16 +01:00
parent 07309013d7
commit a3f78dc801
38 changed files with 409 additions and 409 deletions

View File

@@ -1,6 +1,6 @@
# tests/crud/test_base.py
# tests/repositories/test_base.py
"""
Comprehensive tests for CRUDBase class covering all error paths and edge cases.
Comprehensive tests for BaseRepository class covering all error paths and edge cases.
"""
from datetime import UTC
@@ -16,11 +16,11 @@ from app.core.repository_exceptions import (
IntegrityConstraintError,
InvalidInputError,
)
from app.repositories.user import user_repo as user_crud
from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate, UserUpdate
class TestCRUDBaseGet:
class TestRepositoryBaseGet:
"""Tests for get method covering UUID validation and options."""
@pytest.mark.asyncio
@@ -29,7 +29,7 @@ class TestCRUDBaseGet:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.get(session, id="invalid-uuid")
result = await user_repo.get(session, id="invalid-uuid")
assert result is None
@pytest.mark.asyncio
@@ -38,7 +38,7 @@ class TestCRUDBaseGet:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.get(session, id=12345) # int instead of UUID
result = await user_repo.get(session, id=12345) # int instead of UUID
assert result is None
@pytest.mark.asyncio
@@ -48,7 +48,7 @@ class TestCRUDBaseGet:
async with SessionLocal() as session:
# Pass UUID object directly
result = await user_crud.get(session, id=async_test_user.id)
result = await user_repo.get(session, id=async_test_user.id)
assert result is not None
assert result.id == async_test_user.id
@@ -60,7 +60,7 @@ class TestCRUDBaseGet:
async with SessionLocal() as session:
# Test that options parameter is accepted and doesn't error
# We pass an empty list which still tests the code path
result = await user_crud.get(
result = await user_repo.get(
session, id=str(async_test_user.id), options=[]
)
assert result is not None
@@ -74,10 +74,10 @@ class TestCRUDBaseGet:
# Mock execute to raise an exception
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.get(session, id=str(uuid4()))
await user_repo.get(session, id=str(uuid4()))
class TestCRUDBaseGetMulti:
class TestRepositoryBaseGetMulti:
"""Tests for get_multi method covering pagination validation and options."""
@pytest.mark.asyncio
@@ -87,7 +87,7 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi(session, skip=-1)
await user_repo.get_multi(session, skip=-1)
@pytest.mark.asyncio
async def test_get_multi_negative_limit(self, async_test_db):
@@ -96,7 +96,7 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi(session, limit=-1)
await user_repo.get_multi(session, limit=-1)
@pytest.mark.asyncio
async def test_get_multi_limit_too_large(self, async_test_db):
@@ -105,7 +105,7 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi(session, limit=1001)
await user_repo.get_multi(session, limit=1001)
@pytest.mark.asyncio
async def test_get_multi_with_options(self, async_test_db, async_test_user):
@@ -114,7 +114,7 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session:
# Test that options parameter is accepted
results = await user_crud.get_multi(session, skip=0, limit=10, options=[])
results = await user_repo.get_multi(session, skip=0, limit=10, options=[])
assert isinstance(results, list)
@pytest.mark.asyncio
@@ -125,10 +125,10 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session:
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.get_multi(session)
await user_repo.get_multi(session)
class TestCRUDBaseCreate:
class TestRepositoryBaseCreate:
"""Tests for create method covering various error conditions."""
@pytest.mark.asyncio
@@ -146,7 +146,7 @@ class TestCRUDBaseCreate:
)
with pytest.raises(DuplicateEntryError, match="already exists"):
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio
async def test_create_integrity_error_non_duplicate(self, async_test_db):
@@ -173,11 +173,11 @@ class TestCRUDBaseCreate:
with pytest.raises(
DuplicateEntryError, match="Database integrity error"
):
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio
async def test_create_operational_error(self, async_test_db):
"""Test create with OperationalError (user CRUD catches as generic Exception)."""
"""Test create with OperationalError (user repository catches as generic Exception)."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
@@ -195,13 +195,13 @@ class TestCRUDBaseCreate:
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
# User repository catches this as generic Exception and re-raises
with pytest.raises(OperationalError):
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio
async def test_create_data_error(self, async_test_db):
"""Test create with DataError (user CRUD catches as generic Exception)."""
"""Test create with DataError (user repository catches as generic Exception)."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
@@ -217,9 +217,9 @@ class TestCRUDBaseCreate:
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
# User repository catches this as generic Exception and re-raises
with pytest.raises(DataError):
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db):
@@ -238,10 +238,10 @@ class TestCRUDBaseCreate:
)
with pytest.raises(RuntimeError, match="Unexpected error"):
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
class TestCRUDBaseUpdate:
class TestRepositoryBaseUpdate:
"""Tests for update method covering error conditions."""
@pytest.mark.asyncio
@@ -251,7 +251,7 @@ class TestCRUDBaseUpdate:
# Create another user
async with SessionLocal() as session:
from app.repositories.user import user_repo as user_crud
from app.repositories.user import user_repo as user_repo
user2_data = UserCreate(
email="user2@example.com",
@@ -259,12 +259,12 @@ class TestCRUDBaseUpdate:
first_name="User",
last_name="Two",
)
user2 = await user_crud.create(session, obj_in=user2_data)
user2 = await user_repo.create(session, obj_in=user2_data)
await session.commit()
# Try to update user2 with user1's email
async with SessionLocal() as session:
user2_obj = await user_crud.get(session, id=str(user2.id))
user2_obj = await user_repo.get(session, id=str(user2.id))
with patch.object(
session,
@@ -276,7 +276,7 @@ class TestCRUDBaseUpdate:
update_data = UserUpdate(email=async_test_user.email)
with pytest.raises(DuplicateEntryError, match="already exists"):
await user_crud.update(
await user_repo.update(
session, db_obj=user2_obj, obj_in=update_data
)
@@ -286,10 +286,10 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
user = await user_repo.get(session, id=str(async_test_user.id))
# Update with dict (tests lines 164-165)
updated = await user_crud.update(
updated = await user_repo.update(
session, db_obj=user, obj_in={"first_name": "UpdatedName"}
)
assert updated.first_name == "UpdatedName"
@@ -300,7 +300,7 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
user = await user_repo.get(session, id=str(async_test_user.id))
with patch.object(
session,
@@ -312,7 +312,7 @@ class TestCRUDBaseUpdate:
with pytest.raises(
IntegrityConstraintError, match="Database integrity error"
):
await user_crud.update(
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@@ -322,7 +322,7 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
user = await user_repo.get(session, id=str(async_test_user.id))
with patch.object(
session,
@@ -334,7 +334,7 @@ class TestCRUDBaseUpdate:
with pytest.raises(
IntegrityConstraintError, match="Database operation failed"
):
await user_crud.update(
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@@ -344,18 +344,18 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
user = await user_repo.get(session, id=str(async_test_user.id))
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError):
await user_crud.update(
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
class TestCRUDBaseRemove:
class TestRepositoryBaseRemove:
"""Tests for remove method covering UUID validation and error conditions."""
@pytest.mark.asyncio
@@ -364,7 +364,7 @@ class TestCRUDBaseRemove:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.remove(session, id="invalid-uuid")
result = await user_repo.remove(session, id="invalid-uuid")
assert result is None
@pytest.mark.asyncio
@@ -380,13 +380,13 @@ class TestCRUDBaseRemove:
first_name="To",
last_name="Delete",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_id = user.id
await session.commit()
# Delete with UUID object
async with SessionLocal() as session:
result = await user_crud.remove(session, id=user_id) # UUID object
result = await user_repo.remove(session, id=user_id) # UUID object
assert result is not None
assert result.id == user_id
@@ -396,7 +396,7 @@ class TestCRUDBaseRemove:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.remove(session, id=str(uuid4()))
result = await user_repo.remove(session, id=str(uuid4()))
assert result is None
@pytest.mark.asyncio
@@ -417,7 +417,7 @@ class TestCRUDBaseRemove:
IntegrityConstraintError,
match="Cannot delete.*referenced by other records",
):
await user_crud.remove(session, id=str(async_test_user.id))
await user_repo.remove(session, id=str(async_test_user.id))
@pytest.mark.asyncio
async def test_remove_unexpected_error(self, async_test_db, async_test_user):
@@ -429,10 +429,10 @@ class TestCRUDBaseRemove:
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError):
await user_crud.remove(session, id=str(async_test_user.id))
await user_repo.remove(session, id=str(async_test_user.id))
class TestCRUDBaseGetMultiWithTotal:
class TestRepositoryBaseGetMultiWithTotal:
"""Tests for get_multi_with_total method covering pagination, filtering, sorting."""
@pytest.mark.asyncio
@@ -441,7 +441,7 @@ class TestCRUDBaseGetMultiWithTotal:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total(
items, total = await user_repo.get_multi_with_total(
session, skip=0, limit=10
)
assert isinstance(items, list)
@@ -455,7 +455,7 @@ class TestCRUDBaseGetMultiWithTotal:
async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi_with_total(session, skip=-1)
await user_repo.get_multi_with_total(session, skip=-1)
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db):
@@ -464,7 +464,7 @@ class TestCRUDBaseGetMultiWithTotal:
async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi_with_total(session, limit=-1)
await user_repo.get_multi_with_total(session, limit=-1)
@pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
@@ -473,7 +473,7 @@ class TestCRUDBaseGetMultiWithTotal:
async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, limit=1001)
await user_repo.get_multi_with_total(session, limit=1001)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(
@@ -484,7 +484,7 @@ class TestCRUDBaseGetMultiWithTotal:
async with SessionLocal() as session:
filters = {"email": async_test_user.email}
items, total = await user_crud.get_multi_with_total(
items, total = await user_repo.get_multi_with_total(
session, filters=filters
)
assert total == 1
@@ -512,12 +512,12 @@ class TestCRUDBaseGetMultiWithTotal:
first_name="ZZZ",
last_name="User",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
await user_repo.create(session, obj_in=user_data1)
await user_repo.create(session, obj_in=user_data2)
await session.commit()
async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total(
items, total = await user_repo.get_multi_with_total(
session, sort_by="email", sort_order="asc"
)
assert total >= 3
@@ -545,12 +545,12 @@ class TestCRUDBaseGetMultiWithTotal:
first_name="CCC",
last_name="User",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
await user_repo.create(session, obj_in=user_data1)
await user_repo.create(session, obj_in=user_data2)
await session.commit()
async with SessionLocal() as session:
items, _total = await user_crud.get_multi_with_total(
items, _total = await user_repo.get_multi_with_total(
session, sort_by="email", sort_order="desc", limit=1
)
assert len(items) == 1
@@ -570,19 +570,19 @@ class TestCRUDBaseGetMultiWithTotal:
first_name=f"User{i}",
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
await session.commit()
async with SessionLocal() as session:
# Get first page
items1, total = await user_crud.get_multi_with_total(
items1, total = await user_repo.get_multi_with_total(
session, skip=0, limit=2
)
assert len(items1) == 2
assert total >= 3
# Get second page
items2, total2 = await user_crud.get_multi_with_total(
items2, total2 = await user_repo.get_multi_with_total(
session, skip=2, limit=2
)
assert len(items2) >= 1
@@ -594,7 +594,7 @@ class TestCRUDBaseGetMultiWithTotal:
assert ids1.isdisjoint(ids2)
class TestCRUDBaseCount:
class TestRepositoryBaseCount:
"""Tests for count method."""
@pytest.mark.asyncio
@@ -603,7 +603,7 @@ class TestCRUDBaseCount:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
count = await user_crud.count(session)
count = await user_repo.count(session)
assert isinstance(count, int)
assert count >= 1 # At least the test user
@@ -614,7 +614,7 @@ class TestCRUDBaseCount:
# Create additional users
async with SessionLocal() as session:
initial_count = await user_crud.count(session)
initial_count = await user_repo.count(session)
user_data1 = UserCreate(
email="count1@example.com",
@@ -628,12 +628,12 @@ class TestCRUDBaseCount:
first_name="Count",
last_name="Two",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
await user_repo.create(session, obj_in=user_data1)
await user_repo.create(session, obj_in=user_data2)
await session.commit()
async with SessionLocal() as session:
new_count = await user_crud.count(session)
new_count = await user_repo.count(session)
assert new_count == initial_count + 2
@pytest.mark.asyncio
@@ -644,10 +644,10 @@ class TestCRUDBaseCount:
async with SessionLocal() as session:
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.count(session)
await user_repo.count(session)
class TestCRUDBaseExists:
class TestRepositoryBaseExists:
"""Tests for exists method."""
@pytest.mark.asyncio
@@ -656,7 +656,7 @@ class TestCRUDBaseExists:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(async_test_user.id))
result = await user_repo.exists(session, id=str(async_test_user.id))
assert result is True
@pytest.mark.asyncio
@@ -665,7 +665,7 @@ class TestCRUDBaseExists:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(uuid4()))
result = await user_repo.exists(session, id=str(uuid4()))
assert result is False
@pytest.mark.asyncio
@@ -674,11 +674,11 @@ class TestCRUDBaseExists:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id="invalid-uuid")
result = await user_repo.exists(session, id="invalid-uuid")
assert result is False
class TestCRUDBaseSoftDelete:
class TestRepositoryBaseSoftDelete:
"""Tests for soft_delete method."""
@pytest.mark.asyncio
@@ -694,13 +694,13 @@ class TestCRUDBaseSoftDelete:
first_name="Soft",
last_name="Delete",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_id = user.id
await session.commit()
# Soft delete the user
async with SessionLocal() as session:
deleted = await user_crud.soft_delete(session, id=str(user_id))
deleted = await user_repo.soft_delete(session, id=str(user_id))
assert deleted is not None
assert deleted.deleted_at is not None
@@ -710,7 +710,7 @@ class TestCRUDBaseSoftDelete:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id="invalid-uuid")
result = await user_repo.soft_delete(session, id="invalid-uuid")
assert result is None
@pytest.mark.asyncio
@@ -719,7 +719,7 @@ class TestCRUDBaseSoftDelete:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id=str(uuid4()))
result = await user_repo.soft_delete(session, id=str(uuid4()))
assert result is None
@pytest.mark.asyncio
@@ -735,18 +735,18 @@ class TestCRUDBaseSoftDelete:
first_name="Soft",
last_name="Delete2",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_id = user.id
await session.commit()
# Soft delete with UUID object
async with SessionLocal() as session:
deleted = await user_crud.soft_delete(session, id=user_id) # UUID object
deleted = await user_repo.soft_delete(session, id=user_id) # UUID object
assert deleted is not None
assert deleted.deleted_at is not None
class TestCRUDBaseRestore:
class TestRepositoryBaseRestore:
"""Tests for restore method."""
@pytest.mark.asyncio
@@ -762,16 +762,16 @@ class TestCRUDBaseRestore:
first_name="Restore",
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_id = user.id
await session.commit()
async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id))
await user_repo.soft_delete(session, id=str(user_id))
# Restore the user
async with SessionLocal() as session:
restored = await user_crud.restore(session, id=str(user_id))
restored = await user_repo.restore(session, id=str(user_id))
assert restored is not None
assert restored.deleted_at is None
@@ -781,7 +781,7 @@ class TestCRUDBaseRestore:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.restore(session, id="invalid-uuid")
result = await user_repo.restore(session, id="invalid-uuid")
assert result is None
@pytest.mark.asyncio
@@ -790,7 +790,7 @@ class TestCRUDBaseRestore:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.restore(session, id=str(uuid4()))
result = await user_repo.restore(session, id=str(uuid4()))
assert result is None
@pytest.mark.asyncio
@@ -800,7 +800,7 @@ class TestCRUDBaseRestore:
async with SessionLocal() as session:
# Try to restore a user that's not deleted
result = await user_crud.restore(session, id=str(async_test_user.id))
result = await user_repo.restore(session, id=str(async_test_user.id))
assert result is None
@pytest.mark.asyncio
@@ -816,21 +816,21 @@ class TestCRUDBaseRestore:
first_name="Restore",
last_name="Test2",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_id = user.id
await session.commit()
async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id))
await user_repo.soft_delete(session, id=str(user_id))
# Restore with UUID object
async with SessionLocal() as session:
restored = await user_crud.restore(session, id=user_id) # UUID object
restored = await user_repo.restore(session, id=user_id) # UUID object
assert restored is not None
assert restored.deleted_at is None
class TestCRUDBasePaginationValidation:
class TestRepositoryBasePaginationValidation:
"""Tests for pagination parameter validation (covers lines 254-260)."""
@pytest.mark.asyncio
@@ -840,7 +840,7 @@ class TestCRUDBasePaginationValidation:
async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
await user_repo.get_multi_with_total(session, skip=-1, limit=10)
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db):
@@ -849,7 +849,7 @@ class TestCRUDBasePaginationValidation:
async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
await user_repo.get_multi_with_total(session, skip=0, limit=-1)
@pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
@@ -858,7 +858,7 @@ class TestCRUDBasePaginationValidation:
async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
await user_repo.get_multi_with_total(session, skip=0, limit=1001)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(
@@ -868,7 +868,7 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
users, total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, filters={"is_active": True}
)
assert isinstance(users, list)
@@ -880,7 +880,7 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, _total = await user_crud.get_multi_with_total(
users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
)
assert isinstance(users, list)
@@ -891,13 +891,13 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, _total = await user_crud.get_multi_with_total(
users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
)
assert isinstance(users, list)
class TestCRUDBaseModelsWithoutSoftDelete:
class TestRepositoryBaseModelsWithoutSoftDelete:
"""
Test soft_delete and restore on models without deleted_at column.
Covers lines 342-343, 383-384 - error handling for unsupported models.
@@ -912,7 +912,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
# Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.repositories.organization import organization_repo as org_crud
from app.repositories.organization import organization_repo as org_repo
async with SessionLocal() as session:
org = Organization(name="Test Org", slug="test-org")
@@ -925,7 +925,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
with pytest.raises(
InvalidInputError, match="does not have a deleted_at column"
):
await org_crud.soft_delete(session, id=str(org_id))
await org_repo.soft_delete(session, id=str(org_id))
@pytest.mark.asyncio
async def test_restore_model_without_deleted_at(self, async_test_db):
@@ -934,7 +934,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
# Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.repositories.organization import organization_repo as org_crud
from app.repositories.organization import organization_repo as org_repo
async with SessionLocal() as session:
org = Organization(name="Restore Test", slug="restore-test")
@@ -947,10 +947,10 @@ class TestCRUDBaseModelsWithoutSoftDelete:
with pytest.raises(
InvalidInputError, match="does not have a deleted_at column"
):
await org_crud.restore(session, id=str(org_id))
await org_repo.restore(session, id=str(org_id))
class TestCRUDBaseEagerLoadingWithRealOptions:
class TestRepositoryBaseEagerLoadingWithRealOptions:
"""
Test eager loading with actual SQLAlchemy load options.
Covers lines 77-78, 119-120 - options loop execution.
@@ -967,7 +967,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Create a session for the user
from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud
from app.repositories.session import session_repo as session_repo
async with SessionLocal() as session:
user_session = UserSession(
@@ -985,7 +985,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Get session with eager loading of user relationship
async with SessionLocal() as session:
result = await session_crud.get(
result = await session_repo.get(
session,
id=str(session_id),
options=[joinedload(UserSession.user)], # Real option, not empty list
@@ -1006,7 +1006,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Create multiple sessions for the user
from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud
from app.repositories.session import session_repo as session_repo
async with SessionLocal() as session:
for i in range(3):
@@ -1024,7 +1024,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Get sessions with eager loading
async with SessionLocal() as session:
results = await session_crud.get_multi(
results = await session_repo.get_multi(
session,
skip=0,
limit=10,

View File

@@ -1,6 +1,6 @@
# tests/crud/test_base_db_failures.py
# tests/repositories/test_base_db_failures.py
"""
Comprehensive tests for base CRUD database failure scenarios.
Comprehensive tests for base repository database failure scenarios.
Tests exception handling, rollbacks, and error messages.
"""
@@ -11,16 +11,16 @@ import pytest
from sqlalchemy.exc import DataError, OperationalError
from app.core.repository_exceptions import IntegrityConstraintError
from app.repositories.user import user_repo as user_crud
from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate
class TestBaseCRUDCreateFailures:
"""Test base CRUD create method exception handling."""
class TestBaseRepositoryCreateFailures:
"""Test base repository create method exception handling."""
@pytest.mark.asyncio
async def test_create_operational_error_triggers_rollback(self, async_test_db):
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
"""Test that OperationalError triggers rollback (User repository catches as Exception)."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
@@ -41,16 +41,16 @@ class TestBaseCRUDCreateFailures:
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
# User repository catches this as generic Exception and re-raises
with pytest.raises(OperationalError):
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
# Verify rollback was called
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_create_data_error_triggers_rollback(self, async_test_db):
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
"""Test that DataError triggers rollback (User repository catches as Exception)."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
@@ -69,9 +69,9 @@ class TestBaseCRUDCreateFailures:
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
# User repository catches this as generic Exception and re-raises
with pytest.raises(DataError):
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
mock_rollback.assert_called_once()
@@ -97,13 +97,13 @@ class TestBaseCRUDCreateFailures:
)
with pytest.raises(RuntimeError, match="Unexpected database error"):
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
mock_rollback.assert_called_once()
class TestBaseCRUDUpdateFailures:
"""Test base CRUD update method exception handling."""
class TestBaseRepositoryUpdateFailures:
"""Test base repository update method exception handling."""
@pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user):
@@ -111,7 +111,7 @@ class TestBaseCRUDUpdateFailures:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
user = await user_repo.get(session, id=str(async_test_user.id))
async def mock_commit():
raise OperationalError("Connection timeout", {}, Exception("Timeout"))
@@ -123,7 +123,7 @@ class TestBaseCRUDUpdateFailures:
with pytest.raises(
IntegrityConstraintError, match="Database operation failed"
):
await user_crud.update(
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
@@ -135,7 +135,7 @@ class TestBaseCRUDUpdateFailures:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
user = await user_repo.get(session, id=str(async_test_user.id))
async def mock_commit():
raise DataError("Invalid data", {}, Exception("Data type mismatch"))
@@ -147,7 +147,7 @@ class TestBaseCRUDUpdateFailures:
with pytest.raises(
IntegrityConstraintError, match="Database operation failed"
):
await user_crud.update(
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
@@ -159,7 +159,7 @@ class TestBaseCRUDUpdateFailures:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
user = await user_repo.get(session, id=str(async_test_user.id))
async def mock_commit():
raise KeyError("Unexpected error")
@@ -169,15 +169,15 @@ class TestBaseCRUDUpdateFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(KeyError):
await user_crud.update(
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
class TestBaseCRUDRemoveFailures:
"""Test base CRUD remove method exception handling."""
class TestBaseRepositoryRemoveFailures:
"""Test base repository remove method exception handling."""
@pytest.mark.asyncio
async def test_remove_unexpected_error_triggers_rollback(
@@ -196,12 +196,12 @@ class TestBaseCRUDRemoveFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Database write failed"):
await user_crud.remove(session, id=str(async_test_user.id))
await user_repo.remove(session, id=str(async_test_user.id))
mock_rollback.assert_called_once()
class TestBaseCRUDGetMultiWithTotalFailures:
class TestBaseRepositoryGetMultiWithTotalFailures:
"""Test get_multi_with_total exception handling."""
@pytest.mark.asyncio
@@ -217,10 +217,10 @@ class TestBaseCRUDGetMultiWithTotalFailures:
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi_with_total(session, skip=0, limit=10)
await user_repo.get_multi_with_total(session, skip=0, limit=10)
class TestBaseCRUDCountFailures:
class TestBaseRepositoryCountFailures:
"""Test count method exception handling."""
@pytest.mark.asyncio
@@ -235,10 +235,10 @@ class TestBaseCRUDCountFailures:
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.count(session)
await user_repo.count(session)
class TestBaseCRUDSoftDeleteFailures:
class TestBaseRepositorySoftDeleteFailures:
"""Test soft_delete method exception handling."""
@pytest.mark.asyncio
@@ -258,12 +258,12 @@ class TestBaseCRUDSoftDeleteFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Soft delete failed"):
await user_crud.soft_delete(session, id=str(async_test_user.id))
await user_repo.soft_delete(session, id=str(async_test_user.id))
mock_rollback.assert_called_once()
class TestBaseCRUDRestoreFailures:
class TestBaseRepositoryRestoreFailures:
"""Test restore method exception handling."""
@pytest.mark.asyncio
@@ -279,12 +279,12 @@ class TestBaseCRUDRestoreFailures:
first_name="Restore",
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_id = user.id
await session.commit()
async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id))
await user_repo.soft_delete(session, id=str(user_id))
# Now test restore failure
async with SessionLocal() as session:
@@ -297,12 +297,12 @@ class TestBaseCRUDRestoreFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Restore failed"):
await user_crud.restore(session, id=str(user_id))
await user_repo.restore(session, id=str(user_id))
mock_rollback.assert_called_once()
class TestBaseCRUDGetFailures:
class TestBaseRepositoryGetFailures:
"""Test get method exception handling."""
@pytest.mark.asyncio
@@ -317,10 +317,10 @@ class TestBaseCRUDGetFailures:
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get(session, id=str(uuid4()))
await user_repo.get(session, id=str(uuid4()))
class TestBaseCRUDGetMultiFailures:
class TestBaseRepositoryGetMultiFailures:
"""Test get_multi method exception handling."""
@pytest.mark.asyncio
@@ -335,4 +335,4 @@ class TestBaseCRUDGetMultiFailures:
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi(session, skip=0, limit=10)
await user_repo.get_multi(session, skip=0, limit=10)

View File

@@ -1,6 +1,6 @@
# tests/crud/test_oauth.py
# tests/repositories/test_oauth.py
"""
Comprehensive tests for OAuth CRUD operations.
Comprehensive tests for OAuth repository operations.
"""
from datetime import UTC, datetime, timedelta
@@ -14,8 +14,8 @@ from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
class TestOAuthAccountCRUD:
"""Tests for OAuth account CRUD operations."""
class TestOAuthAccountRepository:
"""Tests for OAuth account repository operations."""
@pytest.mark.asyncio
async def test_create_account(self, async_test_db, async_test_user):
@@ -269,8 +269,8 @@ class TestOAuthAccountCRUD:
assert updated.refresh_token == "new_refresh_token"
class TestOAuthStateCRUD:
"""Tests for OAuth state CRUD operations."""
class TestOAuthStateRepository:
"""Tests for OAuth state repository operations."""
@pytest.mark.asyncio
async def test_create_state(self, async_test_db):
@@ -376,8 +376,8 @@ class TestOAuthStateCRUD:
assert result is not None
class TestOAuthClientCRUD:
"""Tests for OAuth client CRUD operations (provider mode)."""
class TestOAuthClientRepository:
"""Tests for OAuth client repository operations (provider mode)."""
@pytest.mark.asyncio
async def test_create_public_client(self, async_test_db):

View File

@@ -1,6 +1,6 @@
# tests/crud/test_organization_async.py
# tests/repositories/test_organization_async.py
"""
Comprehensive tests for async organization CRUD operations.
Comprehensive tests for async organization repository operations.
"""
from unittest.mock import AsyncMock, MagicMock, patch
@@ -12,7 +12,7 @@ from sqlalchemy import select
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
from app.models.organization import Organization
from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.organization import organization_repo as organization_crud
from app.repositories.organization import organization_repo as organization_repo
from app.schemas.organizations import OrganizationCreate
@@ -35,7 +35,7 @@ class TestGetBySlug:
# Get by slug
async with AsyncTestingSessionLocal() as session:
result = await organization_crud.get_by_slug(session, slug="test-org")
result = await organization_repo.get_by_slug(session, slug="test-org")
assert result is not None
assert result.id == org_id
assert result.slug == "test-org"
@@ -46,7 +46,7 @@ class TestGetBySlug:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await organization_crud.get_by_slug(session, slug="nonexistent")
result = await organization_repo.get_by_slug(session, slug="nonexistent")
assert result is None
@@ -55,7 +55,7 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_success(self, async_test_db):
"""Test successfully creating an organization_crud."""
"""Test successfully creating an organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
@@ -66,7 +66,7 @@ class TestCreate:
is_active=True,
settings={"key": "value"},
)
result = await organization_crud.create(session, obj_in=org_in)
result = await organization_repo.create(session, obj_in=org_in)
assert result.name == "New Org"
assert result.slug == "new-org"
@@ -89,7 +89,7 @@ class TestCreate:
async with AsyncTestingSessionLocal() as session:
org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug")
with pytest.raises(DuplicateEntryError, match="already exists"):
await organization_crud.create(session, obj_in=org_in)
await organization_repo.create(session, obj_in=org_in)
@pytest.mark.asyncio
async def test_create_without_settings(self, async_test_db):
@@ -98,7 +98,7 @@ class TestCreate:
async with AsyncTestingSessionLocal() as session:
org_in = OrganizationCreate(name="No Settings Org", slug="no-settings")
result = await organization_crud.create(session, obj_in=org_in)
result = await organization_repo.create(session, obj_in=org_in)
assert result.settings == {}
@@ -119,7 +119,7 @@ class TestGetMultiWithFilters:
await session.commit()
async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters(session)
orgs, total = await organization_repo.get_multi_with_filters(session)
assert total == 5
assert len(orgs) == 5
@@ -135,7 +135,7 @@ class TestGetMultiWithFilters:
await session.commit()
async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters(
orgs, total = await organization_repo.get_multi_with_filters(
session, is_active=True
)
assert total == 1
@@ -157,7 +157,7 @@ class TestGetMultiWithFilters:
await session.commit()
async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters(
orgs, total = await organization_repo.get_multi_with_filters(
session, search="tech"
)
assert total == 1
@@ -175,7 +175,7 @@ class TestGetMultiWithFilters:
await session.commit()
async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters(
orgs, total = await organization_repo.get_multi_with_filters(
session, skip=2, limit=3
)
assert total == 10
@@ -193,7 +193,7 @@ class TestGetMultiWithFilters:
await session.commit()
async with AsyncTestingSessionLocal() as session:
orgs, _total = await organization_crud.get_multi_with_filters(
orgs, _total = await organization_repo.get_multi_with_filters(
session, sort_by="name", sort_order="asc"
)
assert orgs[0].name == "A Org"
@@ -205,7 +205,7 @@ class TestGetMemberCount:
@pytest.mark.asyncio
async def test_get_member_count_success(self, async_test_db, async_test_user):
"""Test getting member count for organization_crud."""
"""Test getting member count for organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
@@ -225,7 +225,7 @@ class TestGetMemberCount:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
count = await organization_crud.get_member_count(
count = await organization_repo.get_member_count(
session, organization_id=org_id
)
assert count == 1
@@ -242,7 +242,7 @@ class TestGetMemberCount:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
count = await organization_crud.get_member_count(
count = await organization_repo.get_member_count(
session, organization_id=org_id
)
assert count == 0
@@ -253,7 +253,7 @@ class TestAddUser:
@pytest.mark.asyncio
async def test_add_user_success(self, async_test_db, async_test_user):
"""Test successfully adding a user to organization_crud."""
"""Test successfully adding a user to organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
@@ -263,7 +263,7 @@ class TestAddUser:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
result = await organization_crud.add_user(
result = await organization_repo.add_user(
session,
organization_id=org_id,
user_id=async_test_user.id,
@@ -297,7 +297,7 @@ class TestAddUser:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(DuplicateEntryError, match="already a member"):
await organization_crud.add_user(
await organization_repo.add_user(
session, organization_id=org_id, user_id=async_test_user.id
)
@@ -322,7 +322,7 @@ class TestAddUser:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
result = await organization_crud.add_user(
result = await organization_repo.add_user(
session,
organization_id=org_id,
user_id=async_test_user.id,
@@ -338,7 +338,7 @@ class TestRemoveUser:
@pytest.mark.asyncio
async def test_remove_user_success(self, async_test_db, async_test_user):
"""Test successfully removing a user from organization_crud."""
"""Test successfully removing a user from organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
@@ -357,7 +357,7 @@ class TestRemoveUser:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
result = await organization_crud.remove_user(
result = await organization_repo.remove_user(
session, organization_id=org_id, user_id=async_test_user.id
)
@@ -385,7 +385,7 @@ class TestRemoveUser:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
result = await organization_crud.remove_user(
result = await organization_repo.remove_user(
session, organization_id=org_id, user_id=uuid4()
)
@@ -416,7 +416,7 @@ class TestUpdateUserRole:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
result = await organization_crud.update_user_role(
result = await organization_repo.update_user_role(
session,
organization_id=org_id,
user_id=async_test_user.id,
@@ -439,7 +439,7 @@ class TestUpdateUserRole:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
result = await organization_crud.update_user_role(
result = await organization_repo.update_user_role(
session,
organization_id=org_id,
user_id=uuid4(),
@@ -475,7 +475,7 @@ class TestGetOrganizationMembers:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
members, total = await organization_crud.get_organization_members(
members, total = await organization_repo.get_organization_members(
session, organization_id=org_id
)
@@ -508,7 +508,7 @@ class TestGetOrganizationMembers:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
members, total = await organization_crud.get_organization_members(
members, total = await organization_repo.get_organization_members(
session, organization_id=org_id, skip=0, limit=10
)
@@ -539,7 +539,7 @@ class TestGetUserOrganizations:
await session.commit()
async with AsyncTestingSessionLocal() as session:
orgs = await organization_crud.get_user_organizations(
orgs = await organization_repo.get_user_organizations(
session, user_id=async_test_user.id
)
@@ -575,7 +575,7 @@ class TestGetUserOrganizations:
await session.commit()
async with AsyncTestingSessionLocal() as session:
orgs = await organization_crud.get_user_organizations(
orgs = await organization_repo.get_user_organizations(
session, user_id=async_test_user.id, is_active=True
)
@@ -588,7 +588,7 @@ class TestGetUserRole:
@pytest.mark.asyncio
async def test_get_user_role_in_org_success(self, async_test_db, async_test_user):
"""Test getting user role in organization_crud."""
"""Test getting user role in organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
@@ -607,7 +607,7 @@ class TestGetUserRole:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
role = await organization_crud.get_user_role_in_org(
role = await organization_repo.get_user_role_in_org(
session, user_id=async_test_user.id, organization_id=org_id
)
@@ -625,7 +625,7 @@ class TestGetUserRole:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
role = await organization_crud.get_user_role_in_org(
role = await organization_repo.get_user_role_in_org(
session, user_id=uuid4(), organization_id=org_id
)
@@ -656,7 +656,7 @@ class TestIsUserOrgOwner:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
is_owner = await organization_crud.is_user_org_owner(
is_owner = await organization_repo.is_user_org_owner(
session, user_id=async_test_user.id, organization_id=org_id
)
@@ -683,7 +683,7 @@ class TestIsUserOrgOwner:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
is_owner = await organization_crud.is_user_org_owner(
is_owner = await organization_repo.is_user_org_owner(
session, user_id=async_test_user.id, organization_id=org_id
)
@@ -720,7 +720,7 @@ class TestGetMultiWithMemberCounts:
(
orgs_with_counts,
total,
) = await organization_crud.get_multi_with_member_counts(session)
) = await organization_repo.get_multi_with_member_counts(session)
assert total == 2
assert len(orgs_with_counts) == 2
@@ -745,7 +745,7 @@ class TestGetMultiWithMemberCounts:
(
orgs_with_counts,
total,
) = await organization_crud.get_multi_with_member_counts(
) = await organization_repo.get_multi_with_member_counts(
session, is_active=True
)
@@ -767,7 +767,7 @@ class TestGetMultiWithMemberCounts:
(
orgs_with_counts,
total,
) = await organization_crud.get_multi_with_member_counts(
) = await organization_repo.get_multi_with_member_counts(
session, search="tech"
)
@@ -801,7 +801,7 @@ class TestGetUserOrganizationsWithDetails:
async with AsyncTestingSessionLocal() as session:
orgs_with_details = (
await organization_crud.get_user_organizations_with_details(
await organization_repo.get_user_organizations_with_details(
session, user_id=async_test_user.id
)
)
@@ -841,7 +841,7 @@ class TestGetUserOrganizationsWithDetails:
async with AsyncTestingSessionLocal() as session:
orgs_with_details = (
await organization_crud.get_user_organizations_with_details(
await organization_repo.get_user_organizations_with_details(
session, user_id=async_test_user.id, is_active=True
)
)
@@ -874,7 +874,7 @@ class TestIsUserOrgAdmin:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
is_admin = await organization_crud.is_user_org_admin(
is_admin = await organization_repo.is_user_org_admin(
session, user_id=async_test_user.id, organization_id=org_id
)
@@ -901,7 +901,7 @@ class TestIsUserOrgAdmin:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
is_admin = await organization_crud.is_user_org_admin(
is_admin = await organization_repo.is_user_org_admin(
session, user_id=async_test_user.id, organization_id=org_id
)
@@ -928,7 +928,7 @@ class TestIsUserOrgAdmin:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
is_admin = await organization_crud.is_user_org_admin(
is_admin = await organization_repo.is_user_org_admin(
session, user_id=async_test_user.id, organization_id=org_id
)
@@ -937,7 +937,7 @@ class TestIsUserOrgAdmin:
class TestOrganizationExceptionHandlers:
"""
Test exception handlers in organization CRUD methods.
Test exception handlers in organization repository methods.
Uses mocks to trigger database errors and verify proper error handling.
Covers lines: 33-35, 57-62, 114-116, 130-132, 207-209, 258-260, 291-294, 326-329, 385-387, 409-411, 466-468, 491-493
"""
@@ -952,7 +952,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Database connection lost")
):
with pytest.raises(Exception, match="Database connection lost"):
await organization_crud.get_by_slug(session, slug="test-slug")
await organization_repo.get_by_slug(session, slug="test-slug")
@pytest.mark.asyncio
async def test_create_integrity_error_non_slug(self, async_test_db):
@@ -976,7 +976,7 @@ class TestOrganizationExceptionHandlers:
with pytest.raises(
IntegrityConstraintError, match="Database integrity error"
):
await organization_crud.create(session, obj_in=org_in)
await organization_repo.create(session, obj_in=org_in)
@pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db):
@@ -990,7 +990,7 @@ class TestOrganizationExceptionHandlers:
with patch.object(session, "rollback", new_callable=AsyncMock):
org_in = OrganizationCreate(name="Test", slug="test")
with pytest.raises(RuntimeError, match="Unexpected error"):
await organization_crud.create(session, obj_in=org_in)
await organization_repo.create(session, obj_in=org_in)
@pytest.mark.asyncio
async def test_get_multi_with_filters_database_error(self, async_test_db):
@@ -1002,7 +1002,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Query timeout")
):
with pytest.raises(Exception, match="Query timeout"):
await organization_crud.get_multi_with_filters(session)
await organization_repo.get_multi_with_filters(session)
@pytest.mark.asyncio
async def test_get_member_count_database_error(self, async_test_db):
@@ -1016,7 +1016,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Count query failed")
):
with pytest.raises(Exception, match="Count query failed"):
await organization_crud.get_member_count(
await organization_repo.get_member_count(
session, organization_id=uuid4()
)
@@ -1030,7 +1030,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Complex query failed")
):
with pytest.raises(Exception, match="Complex query failed"):
await organization_crud.get_multi_with_member_counts(session)
await organization_repo.get_multi_with_member_counts(session)
@pytest.mark.asyncio
async def test_add_user_integrity_error(self, async_test_db, async_test_user):
@@ -1064,7 +1064,7 @@ class TestOrganizationExceptionHandlers:
IntegrityConstraintError,
match="Failed to add user to organization",
):
await organization_crud.add_user(
await organization_repo.add_user(
session,
organization_id=org_id,
user_id=async_test_user.id,
@@ -1082,7 +1082,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Delete failed")
):
with pytest.raises(Exception, match="Delete failed"):
await organization_crud.remove_user(
await organization_repo.remove_user(
session, organization_id=uuid4(), user_id=async_test_user.id
)
@@ -1100,7 +1100,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Update failed")
):
with pytest.raises(Exception, match="Update failed"):
await organization_crud.update_user_role(
await organization_repo.update_user_role(
session,
organization_id=uuid4(),
user_id=async_test_user.id,
@@ -1119,7 +1119,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Members query failed")
):
with pytest.raises(Exception, match="Members query failed"):
await organization_crud.get_organization_members(
await organization_repo.get_organization_members(
session, organization_id=uuid4()
)
@@ -1135,7 +1135,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("User orgs query failed")
):
with pytest.raises(Exception, match="User orgs query failed"):
await organization_crud.get_user_organizations(
await organization_repo.get_user_organizations(
session, user_id=async_test_user.id
)
@@ -1151,7 +1151,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Details query failed")
):
with pytest.raises(Exception, match="Details query failed"):
await organization_crud.get_user_organizations_with_details(
await organization_repo.get_user_organizations_with_details(
session, user_id=async_test_user.id
)
@@ -1169,6 +1169,6 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Role query failed")
):
with pytest.raises(Exception, match="Role query failed"):
await organization_crud.get_user_role_in_org(
await organization_repo.get_user_role_in_org(
session, user_id=async_test_user.id, organization_id=uuid4()
)

View File

@@ -1,6 +1,6 @@
# tests/crud/test_session_async.py
# tests/repositories/test_session_async.py
"""
Comprehensive tests for async session CRUD operations.
Comprehensive tests for async session repository operations.
"""
from datetime import UTC, datetime, timedelta
@@ -10,7 +10,7 @@ import pytest
from app.core.repository_exceptions import InvalidInputError
from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud
from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate
@@ -37,7 +37,7 @@ class TestGetByJti:
await session.commit()
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="test_jti_123")
result = await session_repo.get_by_jti(session, jti="test_jti_123")
assert result is not None
assert result.refresh_token_jti == "test_jti_123"
@@ -47,7 +47,7 @@ class TestGetByJti:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="nonexistent")
result = await session_repo.get_by_jti(session, jti="nonexistent")
assert result is None
@@ -74,7 +74,7 @@ class TestGetActiveByJti:
await session.commit()
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_active_by_jti(session, jti="active_jti")
result = await session_repo.get_active_by_jti(session, jti="active_jti")
assert result is not None
assert result.is_active is True
@@ -98,7 +98,7 @@ class TestGetActiveByJti:
await session.commit()
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_active_by_jti(session, jti="inactive_jti")
result = await session_repo.get_active_by_jti(session, jti="inactive_jti")
assert result is None
@@ -135,7 +135,7 @@ class TestGetUserSessions:
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
results = await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=True
)
assert len(results) == 1
@@ -162,7 +162,7 @@ class TestGetUserSessions:
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
results = await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=False
)
assert len(results) == 3
@@ -173,7 +173,7 @@ class TestCreateSession:
@pytest.mark.asyncio
async def test_create_session_success(self, async_test_db, async_test_user):
"""Test successfully creating a session_crud."""
"""Test successfully creating a session_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
@@ -189,7 +189,7 @@ class TestCreateSession:
location_city="San Francisco",
location_country="USA",
)
result = await session_crud.create_session(session, obj_in=session_data)
result = await session_repo.create_session(session, obj_in=session_data)
assert result.user_id == async_test_user.id
assert result.refresh_token_jti == "new_jti"
@@ -202,7 +202,7 @@ class TestDeactivate:
@pytest.mark.asyncio
async def test_deactivate_success(self, async_test_db, async_test_user):
"""Test successfully deactivating a session_crud."""
"""Test successfully deactivating a session_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
@@ -221,7 +221,7 @@ class TestDeactivate:
session_id = user_session.id
async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(session_id))
result = await session_repo.deactivate(session, session_id=str(session_id))
assert result is not None
assert result.is_active is False
@@ -231,7 +231,7 @@ class TestDeactivate:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(uuid4()))
result = await session_repo.deactivate(session, session_id=str(uuid4()))
assert result is None
@@ -262,7 +262,7 @@ class TestDeactivateAllUserSessions:
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.deactivate_all_user_sessions(
count = await session_repo.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id)
)
assert count == 2
@@ -292,7 +292,7 @@ class TestUpdateLastUsed:
await session.refresh(user_session)
old_time = user_session.last_used_at
result = await session_crud.update_last_used(session, session=user_session)
result = await session_repo.update_last_used(session, session=user_session)
assert result.last_used_at > old_time
@@ -321,7 +321,7 @@ class TestGetUserSessionCount:
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
count = await session_repo.get_user_session_count(
session, user_id=str(async_test_user.id)
)
assert count == 3
@@ -332,7 +332,7 @@ class TestGetUserSessionCount:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
count = await session_repo.get_user_session_count(
session, user_id=str(uuid4())
)
assert count == 0
@@ -364,7 +364,7 @@ class TestUpdateRefreshToken:
new_jti = "new_jti_123"
new_expires = datetime.now(UTC) + timedelta(days=14)
result = await session_crud.update_refresh_token(
result = await session_repo.update_refresh_token(
session,
session=user_session,
new_jti=new_jti,
@@ -410,7 +410,7 @@ class TestCleanupExpired:
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30)
count = await session_repo.cleanup_expired(session, keep_days=30)
assert count == 1
@pytest.mark.asyncio
@@ -436,7 +436,7 @@ class TestCleanupExpired:
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30)
count = await session_repo.cleanup_expired(session, keep_days=30)
assert count == 0 # Should not delete recent sessions
@pytest.mark.asyncio
@@ -462,7 +462,7 @@ class TestCleanupExpired:
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30)
count = await session_repo.cleanup_expired(session, keep_days=30)
assert count == 0 # Should not delete active sessions
@@ -493,7 +493,7 @@ class TestCleanupExpiredForUser:
# Cleanup for user
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
count = await session_repo.cleanup_expired_for_user(
session, user_id=str(async_test_user.id)
)
assert count == 1
@@ -505,7 +505,7 @@ class TestCleanupExpiredForUser:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError, match="Invalid user ID format"):
await session_crud.cleanup_expired_for_user(
await session_repo.cleanup_expired_for_user(
session, user_id="not-a-valid-uuid"
)
@@ -533,7 +533,7 @@ class TestCleanupExpiredForUser:
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
count = await session_repo.cleanup_expired_for_user(
session, user_id=str(async_test_user.id)
)
assert count == 0 # Should not delete active sessions
@@ -565,7 +565,7 @@ class TestGetUserSessionsWithUser:
# Get with user relationship
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
results = await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id), with_user=True
)
assert len(results) >= 1

View File

@@ -1,6 +1,6 @@
# tests/crud/test_session_db_failures.py
# tests/repositories/test_session_db_failures.py
"""
Comprehensive tests for session CRUD database failure scenarios.
Comprehensive tests for session repository database failure scenarios.
"""
from datetime import UTC, datetime, timedelta
@@ -12,11 +12,11 @@ from sqlalchemy.exc import OperationalError
from app.core.repository_exceptions import IntegrityConstraintError
from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud
from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate
class TestSessionCRUDGetByJtiFailures:
class TestSessionRepositoryGetByJtiFailures:
"""Test get_by_jti exception handling."""
@pytest.mark.asyncio
@@ -31,10 +31,10 @@ class TestSessionCRUDGetByJtiFailures:
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_by_jti(session, jti="test_jti")
await session_repo.get_by_jti(session, jti="test_jti")
class TestSessionCRUDGetActiveByJtiFailures:
class TestSessionRepositoryGetActiveByJtiFailures:
"""Test get_active_by_jti exception handling."""
@pytest.mark.asyncio
@@ -49,10 +49,10 @@ class TestSessionCRUDGetActiveByJtiFailures:
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_active_by_jti(session, jti="test_jti")
await session_repo.get_active_by_jti(session, jti="test_jti")
class TestSessionCRUDGetUserSessionsFailures:
class TestSessionRepositoryGetUserSessionsFailures:
"""Test get_user_sessions exception handling."""
@pytest.mark.asyncio
@@ -69,12 +69,12 @@ class TestSessionCRUDGetUserSessionsFailures:
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_sessions(
await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id)
)
class TestSessionCRUDCreateSessionFailures:
class TestSessionRepositoryCreateSessionFailures:
"""Test create_session exception handling."""
@pytest.mark.asyncio
@@ -106,7 +106,7 @@ class TestSessionCRUDCreateSessionFailures:
with pytest.raises(
IntegrityConstraintError, match="Failed to create session"
):
await session_crud.create_session(session, obj_in=session_data)
await session_repo.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once()
@@ -139,12 +139,12 @@ class TestSessionCRUDCreateSessionFailures:
with pytest.raises(
IntegrityConstraintError, match="Failed to create session"
):
await session_crud.create_session(session, obj_in=session_data)
await session_repo.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once()
class TestSessionCRUDDeactivateFailures:
class TestSessionRepositoryDeactivateFailures:
"""Test deactivate exception handling."""
@pytest.mark.asyncio
@@ -182,14 +182,14 @@ class TestSessionCRUDDeactivateFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate(
await session_repo.deactivate(
session, session_id=str(session_id)
)
mock_rollback.assert_called_once()
class TestSessionCRUDDeactivateAllFailures:
class TestSessionRepositoryDeactivateAllFailures:
"""Test deactivate_all_user_sessions exception handling."""
@pytest.mark.asyncio
@@ -209,14 +209,14 @@ class TestSessionCRUDDeactivateAllFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate_all_user_sessions(
await session_repo.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
class TestSessionCRUDUpdateLastUsedFailures:
class TestSessionRepositoryUpdateLastUsedFailures:
"""Test update_last_used exception handling."""
@pytest.mark.asyncio
@@ -259,12 +259,12 @@ class TestSessionCRUDUpdateLastUsedFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_last_used(session, session=sess)
await session_repo.update_last_used(session, session=sess)
mock_rollback.assert_called_once()
class TestSessionCRUDUpdateRefreshTokenFailures:
class TestSessionRepositoryUpdateRefreshTokenFailures:
"""Test update_refresh_token exception handling."""
@pytest.mark.asyncio
@@ -307,7 +307,7 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_refresh_token(
await session_repo.update_refresh_token(
session,
session=sess,
new_jti=str(uuid4()),
@@ -317,7 +317,7 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
mock_rollback.assert_called_once()
class TestSessionCRUDCleanupExpiredFailures:
class TestSessionRepositoryCleanupExpiredFailures:
"""Test cleanup_expired exception handling."""
@pytest.mark.asyncio
@@ -337,12 +337,12 @@ class TestSessionCRUDCleanupExpiredFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired(session, keep_days=30)
await session_repo.cleanup_expired(session, keep_days=30)
mock_rollback.assert_called_once()
class TestSessionCRUDCleanupExpiredForUserFailures:
class TestSessionRepositoryCleanupExpiredForUserFailures:
"""Test cleanup_expired_for_user exception handling."""
@pytest.mark.asyncio
@@ -362,14 +362,14 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired_for_user(
await session_repo.cleanup_expired_for_user(
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
class TestSessionCRUDGetUserSessionCountFailures:
class TestSessionRepositoryGetUserSessionCountFailures:
"""Test get_user_session_count exception handling."""
@pytest.mark.asyncio
@@ -386,6 +386,6 @@ class TestSessionCRUDGetUserSessionCountFailures:
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_session_count(
await session_repo.get_user_session_count(
session, user_id=str(async_test_user.id)
)

View File

@@ -1,12 +1,12 @@
# tests/crud/test_user_async.py
# tests/repositories/test_user_async.py
"""
Comprehensive tests for async user CRUD operations.
Comprehensive tests for async user repository operations.
"""
import pytest
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
from app.repositories.user import user_repo as user_crud
from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate, UserUpdate
@@ -19,7 +19,7 @@ class TestGetByEmail:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email=async_test_user.email)
result = await user_repo.get_by_email(session, email=async_test_user.email)
assert result is not None
assert result.email == async_test_user.email
assert result.id == async_test_user.id
@@ -30,7 +30,7 @@ class TestGetByEmail:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(
result = await user_repo.get_by_email(
session, email="nonexistent@example.com"
)
assert result is None
@@ -41,7 +41,7 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_user_success(self, async_test_db):
"""Test successfully creating a user_crud."""
"""Test successfully creating a user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
@@ -52,7 +52,7 @@ class TestCreate:
last_name="User",
phone_number="+1234567890",
)
result = await user_crud.create(session, obj_in=user_data)
result = await user_repo.create(session, obj_in=user_data)
assert result.email == "newuser@example.com"
assert result.first_name == "New"
@@ -76,7 +76,7 @@ class TestCreate:
last_name="User",
is_superuser=True,
)
result = await user_crud.create(session, obj_in=user_data)
result = await user_repo.create(session, obj_in=user_data)
assert result.is_superuser is True
assert result.email == "superuser@example.com"
@@ -95,7 +95,7 @@ class TestCreate:
)
with pytest.raises(DuplicateEntryError) as exc_info:
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
assert "already exists" in str(exc_info.value).lower()
@@ -110,12 +110,12 @@ class TestUpdate:
async with AsyncTestingSessionLocal() as session:
# Get fresh copy of user
user = await user_crud.get(session, id=str(async_test_user.id))
user = await user_repo.get(session, id=str(async_test_user.id))
update_data = UserUpdate(
first_name="Updated", last_name="Name", phone_number="+9876543210"
)
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
result = await user_repo.update(session, db_obj=user, obj_in=update_data)
assert result.first_name == "Updated"
assert result.last_name == "Name"
@@ -134,16 +134,16 @@ class TestUpdate:
first_name="Pass",
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_id = user.id
old_password_hash = user.password_hash
# Update the password
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(user_id))
user = await user_repo.get(session, id=str(user_id))
update_data = UserUpdate(password="NewDifferentPassword123!")
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
result = await user_repo.update(session, db_obj=user, obj_in=update_data)
await session.refresh(result)
assert result.password_hash != old_password_hash
@@ -158,10 +158,10 @@ class TestUpdate:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
user = await user_repo.get(session, id=str(async_test_user.id))
update_dict = {"first_name": "DictUpdate"}
result = await user_crud.update(session, db_obj=user, obj_in=update_dict)
result = await user_repo.update(session, db_obj=user, obj_in=update_dict)
assert result.first_name == "DictUpdate"
@@ -175,7 +175,7 @@ class TestGetMultiWithTotal:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
users, total = await user_repo.get_multi_with_total(
session, skip=0, limit=10
)
assert total >= 1
@@ -196,10 +196,10 @@ class TestGetMultiWithTotal:
first_name=f"User{i}",
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total(
users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="asc"
)
@@ -222,10 +222,10 @@ class TestGetMultiWithTotal:
first_name=f"User{i}",
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total(
users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="desc"
)
@@ -247,7 +247,7 @@ class TestGetMultiWithTotal:
first_name="Active",
last_name="User",
)
await user_crud.create(session, obj_in=active_user)
await user_repo.create(session, obj_in=active_user)
inactive_user = UserCreate(
email="inactive@example.com",
@@ -255,15 +255,15 @@ class TestGetMultiWithTotal:
first_name="Inactive",
last_name="User",
)
created_inactive = await user_crud.create(session, obj_in=inactive_user)
created_inactive = await user_repo.create(session, obj_in=inactive_user)
# Deactivate the user
await user_crud.update(
await user_repo.update(
session, db_obj=created_inactive, obj_in={"is_active": False}
)
async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total(
users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=100, filters={"is_active": True}
)
@@ -283,10 +283,10 @@ class TestGetMultiWithTotal:
first_name="Searchable",
last_name="UserName",
)
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
users, total = await user_repo.get_multi_with_total(
session, skip=0, limit=100, search="Searchable"
)
@@ -307,16 +307,16 @@ class TestGetMultiWithTotal:
first_name=f"Page{i}",
last_name="User",
)
await user_crud.create(session, obj_in=user_data)
await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
# Get first page
users_page1, total = await user_crud.get_multi_with_total(
users_page1, total = await user_repo.get_multi_with_total(
session, skip=0, limit=2
)
# Get second page
users_page2, total2 = await user_crud.get_multi_with_total(
users_page2, total2 = await user_repo.get_multi_with_total(
session, skip=2, limit=2
)
@@ -332,7 +332,7 @@ class TestGetMultiWithTotal:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
await user_repo.get_multi_with_total(session, skip=-1, limit=10)
assert "skip must be non-negative" in str(exc_info.value)
@@ -343,7 +343,7 @@ class TestGetMultiWithTotal:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
await user_repo.get_multi_with_total(session, skip=0, limit=-1)
assert "limit must be non-negative" in str(exc_info.value)
@@ -354,7 +354,7 @@ class TestGetMultiWithTotal:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
await user_repo.get_multi_with_total(session, skip=0, limit=1001)
assert "Maximum limit is 1000" in str(exc_info.value)
@@ -377,12 +377,12 @@ class TestBulkUpdateStatus:
first_name=f"Bulk{i}",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_ids.append(user.id)
# Bulk deactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
count = await user_repo.bulk_update_status(
session, user_ids=user_ids, is_active=False
)
assert count == 3
@@ -390,7 +390,7 @@ class TestBulkUpdateStatus:
# Verify all are inactive
async with AsyncTestingSessionLocal() as session:
for user_id in user_ids:
user = await user_crud.get(session, id=str(user_id))
user = await user_repo.get(session, id=str(user_id))
assert user.is_active is False
@pytest.mark.asyncio
@@ -399,7 +399,7 @@ class TestBulkUpdateStatus:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
count = await user_repo.bulk_update_status(
session, user_ids=[], is_active=False
)
assert count == 0
@@ -417,21 +417,21 @@ class TestBulkUpdateStatus:
first_name="Reactivate",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
# Deactivate
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
await user_repo.update(session, db_obj=user, obj_in={"is_active": False})
user_id = user.id
# Reactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
count = await user_repo.bulk_update_status(
session, user_ids=[user_id], is_active=True
)
assert count == 1
# Verify active
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(user_id))
user = await user_repo.get(session, id=str(user_id))
assert user.is_active is True
@@ -453,24 +453,24 @@ class TestBulkSoftDelete:
first_name=f"Delete{i}",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_ids.append(user.id)
# Bulk delete
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
count = await user_repo.bulk_soft_delete(session, user_ids=user_ids)
assert count == 3
# Verify all are soft deleted
async with AsyncTestingSessionLocal() as session:
for user_id in user_ids:
user = await user_crud.get(session, id=str(user_id))
user = await user_repo.get(session, id=str(user_id))
assert user.deleted_at is not None
assert user.is_active is False
@pytest.mark.asyncio
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
"""Test bulk soft delete with excluded user_crud."""
"""Test bulk soft delete with excluded user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
@@ -483,20 +483,20 @@ class TestBulkSoftDelete:
first_name=f"Exclude{i}",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_ids.append(user.id)
# Bulk delete, excluding first user
exclude_id = user_ids[0]
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
count = await user_repo.bulk_soft_delete(
session, user_ids=user_ids, exclude_user_id=exclude_id
)
assert count == 2 # Only 2 deleted
# Verify excluded user is NOT deleted
async with AsyncTestingSessionLocal() as session:
excluded_user = await user_crud.get(session, id=str(exclude_id))
excluded_user = await user_repo.get(session, id=str(exclude_id))
assert excluded_user.deleted_at is None
@pytest.mark.asyncio
@@ -505,7 +505,7 @@ class TestBulkSoftDelete:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=[])
count = await user_repo.bulk_soft_delete(session, user_ids=[])
assert count == 0
@pytest.mark.asyncio
@@ -521,12 +521,12 @@ class TestBulkSoftDelete:
first_name="Only",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_id = user.id
# Try to delete but exclude
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
count = await user_repo.bulk_soft_delete(
session, user_ids=[user_id], exclude_user_id=user_id
)
assert count == 0
@@ -544,15 +544,15 @@ class TestBulkSoftDelete:
first_name="PreDeleted",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user = await user_repo.create(session, obj_in=user_data)
user_id = user.id
# First deletion
await user_crud.bulk_soft_delete(session, user_ids=[user_id])
await user_repo.bulk_soft_delete(session, user_ids=[user_id])
# Try to delete again
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
count = await user_repo.bulk_soft_delete(session, user_ids=[user_id])
assert count == 0 # Already deleted
@@ -561,16 +561,16 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_active_true(self, async_test_db, async_test_user):
"""Test is_active returns True for active user_crud."""
"""Test is_active returns True for active user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
assert user_crud.is_active(user) is True
user = await user_repo.get(session, id=str(async_test_user.id))
assert user_repo.is_active(user) is True
@pytest.mark.asyncio
async def test_is_active_false(self, async_test_db):
"""Test is_active returns False for inactive user_crud."""
"""Test is_active returns False for inactive user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
@@ -580,10 +580,10 @@ class TestUtilityMethods:
first_name="Inactive",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
user = await user_repo.create(session, obj_in=user_data)
await user_repo.update(session, db_obj=user, obj_in={"is_active": False})
assert user_crud.is_active(user) is False
assert user_repo.is_active(user) is False
@pytest.mark.asyncio
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
@@ -591,22 +591,22 @@ class TestUtilityMethods:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_superuser.id))
assert user_crud.is_superuser(user) is True
user = await user_repo.get(session, id=str(async_test_superuser.id))
assert user_repo.is_superuser(user) is True
@pytest.mark.asyncio
async def test_is_superuser_false(self, async_test_db, async_test_user):
"""Test is_superuser returns False for regular user_crud."""
"""Test is_superuser returns False for regular user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
assert user_crud.is_superuser(user) is False
user = await user_repo.get(session, id=str(async_test_user.id))
assert user_repo.is_superuser(user) is False
class TestUserExceptionHandlers:
"""
Test exception handlers in user CRUD methods.
Test exception handlers in user repository methods.
Covers lines: 30-32, 205-208, 257-260
"""
@@ -622,7 +622,7 @@ class TestUserExceptionHandlers:
session, "execute", side_effect=Exception("Database query failed")
):
with pytest.raises(Exception, match="Database query failed"):
await user_crud.get_by_email(session, email="test@example.com")
await user_repo.get_by_email(session, email="test@example.com")
@pytest.mark.asyncio
async def test_bulk_update_status_database_error(
@@ -640,7 +640,7 @@ class TestUserExceptionHandlers:
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk update failed"):
await user_crud.bulk_update_status(
await user_repo.bulk_update_status(
session, user_ids=[async_test_user.id], is_active=False
)
@@ -660,6 +660,6 @@ class TestUserExceptionHandlers:
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk delete failed"):
await user_crud.bulk_soft_delete(
await user_repo.bulk_soft_delete(
session, user_ids=[async_test_user.id]
)