Add extensive tests for handling CRUD and API error scenarios
- Introduced comprehensive tests for session CRUD error cases, covering exception handling, rollback mechanics, and database failure propagation. - Added robust API error handling tests for admin routes, including user and organization management. - Enhanced test coverage for unexpected errors, edge cases, and validation flows in session and admin operations.
This commit is contained in:
@@ -757,3 +757,79 @@ class TestCRUDBaseRestore:
|
||||
restored = await user_crud.restore(session, id=user_id) # UUID object
|
||||
assert restored is not None
|
||||
assert restored.deleted_at is None
|
||||
|
||||
|
||||
class TestCRUDBasePaginationValidation:
|
||||
"""Tests for pagination parameter validation (covers lines 254-260)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_skip(self, async_test_db):
|
||||
"""Test that negative skip raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_limit(self, async_test_db):
|
||||
"""Test that negative limit raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
|
||||
"""Test that limit > 1000 raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
|
||||
"""Test pagination with filters (covers lines 270-273)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
filters={"is_active": True}
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
assert total >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db):
|
||||
"""Test pagination with descending sort (covers lines 283-284)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="created_at",
|
||||
sort_order="desc"
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db):
|
||||
"""Test pagination with ascending sort (covers lines 285-286)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="created_at",
|
||||
sort_order="asc"
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
|
||||
293
backend/tests/crud/test_base_db_failures.py
Normal file
293
backend/tests/crud/test_base_db_failures.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# tests/crud/test_base_db_failures.py
|
||||
"""
|
||||
Comprehensive tests for base CRUD database failure scenarios.
|
||||
Tests exception handling, rollbacks, and error messages.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class TestBaseCRUDCreateFailures:
|
||||
"""Test base CRUD create method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_operational_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection lost", {}, Exception("DB connection failed"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="operror@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
# Verify rollback was called
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_data_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise DataError("Invalid data type", {}, Exception("Data overflow"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="dataerror@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
with pytest.raises(DataError):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_unexpected_exception_triggers_rollback(self, async_test_db):
|
||||
"""Test that unexpected exceptions trigger rollback and re-raise."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Unexpected database error")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="unexpected@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Unexpected database error"):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDUpdateFailures:
|
||||
"""Test base CRUD update method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_operational_error(self, async_test_db, async_test_user):
|
||||
"""Test update with OperationalError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection timeout", {}, Exception("Timeout"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_data_error(self, async_test_db, async_test_user):
|
||||
"""Test update with DataError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
async def mock_commit():
|
||||
raise DataError("Invalid data", {}, Exception("Data type mismatch"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test update with unexpected error."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
async def mock_commit():
|
||||
raise KeyError("Unexpected error")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(KeyError):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDRemoveFailures:
|
||||
"""Test base CRUD remove method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
|
||||
"""Test that unexpected errors in remove trigger rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Database write failed")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Database write failed"):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDGetMultiWithTotalFailures:
|
||||
"""Test get_multi_with_total exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_database_error(self, async_test_db):
|
||||
"""Test get_multi_with_total handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock execute to raise an error
|
||||
original_execute = session.execute
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query failed", {}, Exception("Database error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=10)
|
||||
|
||||
|
||||
class TestBaseCRUDCountFailures:
|
||||
"""Test count method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_database_error_propagates(self, async_test_db):
|
||||
"""Test count propagates database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Count failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.count(session)
|
||||
|
||||
|
||||
class TestBaseCRUDSoftDeleteFailures:
|
||||
"""Test soft_delete method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
|
||||
"""Test soft_delete handles unexpected errors with rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Soft delete failed")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Soft delete failed"):
|
||||
await user_crud.soft_delete(session, id=str(async_test_user.id))
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDRestoreFailures:
|
||||
"""Test restore method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_unexpected_error_triggers_rollback(self, async_test_db):
|
||||
"""Test restore handles unexpected errors with rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# First create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="restore_test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
await user_crud.soft_delete(session, id=str(user_id))
|
||||
|
||||
# Now test restore failure
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Restore failed")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Restore failed"):
|
||||
await user_crud.restore(session, id=str(user_id))
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDGetFailures:
|
||||
"""Test get method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_database_error_propagates(self, async_test_db):
|
||||
"""Test get propagates database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Get failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get(session, id=str(uuid4()))
|
||||
|
||||
|
||||
class TestBaseCRUDGetMultiFailures:
|
||||
"""Test get_multi method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_database_error_propagates(self, async_test_db):
|
||||
"""Test get_multi propagates database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get_multi(session, skip=0, limit=10)
|
||||
336
backend/tests/crud/test_session_db_failures.py
Normal file
336
backend/tests/crud/test_session_db_failures.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# tests/crud/test_session_db_failures.py
|
||||
"""
|
||||
Comprehensive tests for session CRUD database failure scenarios.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from sqlalchemy.exc import OperationalError, IntegrityError
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
|
||||
class TestSessionCRUDGetByJtiFailures:
|
||||
"""Test get_by_jti exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_database_error(self, async_test_db):
|
||||
"""Test get_by_jti handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("DB connection lost", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_by_jti(session, jti="test_jti")
|
||||
|
||||
|
||||
class TestSessionCRUDGetActiveByJtiFailures:
|
||||
"""Test get_active_by_jti exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_database_error(self, async_test_db):
|
||||
"""Test get_active_by_jti handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query timeout", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_active_by_jti(session, jti="test_jti")
|
||||
|
||||
|
||||
class TestSessionCRUDGetUserSessionsFailures:
|
||||
"""Test get_user_sessions exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_database_error(self, async_test_db, async_test_user):
|
||||
"""Test get_user_sessions handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Database error", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
|
||||
class TestSessionCRUDCreateSessionFailures:
|
||||
"""Test create_session exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
"""Test create_session handles commit failures with rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise OperationalError("Commit failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to create session"):
|
||||
await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
|
||||
"""Test create_session handles unexpected errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Unexpected error")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to create session"):
|
||||
await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDDeactivateFailures:
|
||||
"""Test deactivate exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
"""Test deactivate handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session first
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
session_id = user_session.id
|
||||
|
||||
# Test deactivate failure
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise OperationalError("Deactivate failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.deactivate(session, session_id=str(session_id))
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDDeactivateAllFailures:
|
||||
"""Test deactivate_all_user_sessions exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
"""Test deactivate_all handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise OperationalError("Bulk deactivate failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.deactivate_all_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDUpdateLastUsedFailures:
|
||||
"""Test update_last_used exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_last_used_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
"""Test update_last_used handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
# Test update failure
|
||||
async with SessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
from app.models.user_session import UserSession as US
|
||||
result = await session.execute(select(US).where(US.id == user_session.id))
|
||||
sess = result.scalar_one()
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Update failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.update_last_used(session, session=sess)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDUpdateRefreshTokenFailures:
|
||||
"""Test update_refresh_token exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
"""Test update_refresh_token handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
# Test update failure
|
||||
async with SessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
from app.models.user_session import UserSession as US
|
||||
result = await session.execute(select(US).where(US.id == user_session.id))
|
||||
sess = result.scalar_one()
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Token update failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.update_refresh_token(
|
||||
session,
|
||||
session=sess,
|
||||
new_jti=str(uuid4()),
|
||||
new_expires_at=datetime.now(timezone.utc) + timedelta(days=14)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDCleanupExpiredFailures:
|
||||
"""Test cleanup_expired exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_commit_failure_triggers_rollback(self, async_test_db):
|
||||
"""Test cleanup_expired handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise OperationalError("Cleanup failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.cleanup_expired(session, keep_days=30)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDCleanupExpiredForUserFailures:
|
||||
"""Test cleanup_expired_for_user exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
"""Test cleanup_expired_for_user handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise OperationalError("User cleanup failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDGetUserSessionCountFailures:
|
||||
"""Test get_user_session_count exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_database_error(self, async_test_db, async_test_user):
|
||||
"""Test get_user_session_count handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Count query failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_user_session_count(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
Reference in New Issue
Block a user