refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError, IntegrityConstraintError, InvalidInputError) replacing raw ValueError - eliminate all direct repository imports and raw SQL from route layer - add UserService, SessionService, OrganizationService to service layer - add get_stats/get_org_distribution service methods replacing admin inline SQL - fix timing side-channel in authenticate_user via dummy bcrypt check - replace SHA-256 client secret fallback with explicit InvalidClientError - replace assert with InvalidGrantError in authorization code exchange - replace N+1 token revocation loops with bulk UPDATE statements - rename oauth account token fields (drop misleading 'encrypted' suffix) - add Alembic migration 0003 for token field column rename - add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
0
backend/tests/repositories/__init__.py
Executable file
0
backend/tests/repositories/__init__.py
Executable file
1026
backend/tests/repositories/test_base.py
Normal file
1026
backend/tests/repositories/test_base.py
Normal file
File diff suppressed because it is too large
Load Diff
334
backend/tests/repositories/test_base_db_failures.py
Normal file
334
backend/tests/repositories/test_base_db_failures.py
Normal file
@@ -0,0 +1,334 @@
|
||||
# tests/crud/test_base_db_failures.py
|
||||
"""
|
||||
Comprehensive tests for base CRUD database failure scenarios.
|
||||
Tests exception handling, rollbacks, and error messages.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import DataError, OperationalError
|
||||
|
||||
from app.core.repository_exceptions import IntegrityConstraintError
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
class TestBaseCRUDCreateFailures:
|
||||
"""Test base CRUD create method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_operational_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError(
|
||||
"Connection lost", {}, Exception("DB connection failed")
|
||||
)
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="operror@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
# Verify rollback was called
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_data_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise DataError("Invalid data type", {}, Exception("Data overflow"))
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="dataerror@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
with pytest.raises(DataError):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_unexpected_exception_triggers_rollback(self, async_test_db):
|
||||
"""Test that unexpected exceptions trigger rollback and re-raise."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Unexpected database error")
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="unexpected@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Unexpected database error"):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDUpdateFailures:
|
||||
"""Test base CRUD update method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_operational_error(self, async_test_db, async_test_user):
|
||||
"""Test update with OperationalError."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection timeout", {}, Exception("Timeout"))
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_data_error(self, async_test_db, async_test_user):
|
||||
"""Test update with DataError."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
async def mock_commit():
|
||||
raise DataError("Invalid data", {}, Exception("Data type mismatch"))
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test update with unexpected error."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
async def mock_commit():
|
||||
raise KeyError("Unexpected error")
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(KeyError):
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDRemoveFailures:
|
||||
"""Test base CRUD remove method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_unexpected_error_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test that unexpected errors in remove trigger rollback."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Database write failed")
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Database write failed"):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDGetMultiWithTotalFailures:
|
||||
"""Test get_multi_with_total exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_database_error(self, async_test_db):
|
||||
"""Test get_multi_with_total handles database errors."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock execute to raise an error
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query failed", {}, Exception("Database error"))
|
||||
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=10)
|
||||
|
||||
|
||||
class TestBaseCRUDCountFailures:
|
||||
"""Test count method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_database_error_propagates(self, async_test_db):
|
||||
"""Test count propagates database errors."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Count failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.count(session)
|
||||
|
||||
|
||||
class TestBaseCRUDSoftDeleteFailures:
|
||||
"""Test soft_delete method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_unexpected_error_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test soft_delete handles unexpected errors with rollback."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Soft delete failed")
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Soft delete failed"):
|
||||
await user_crud.soft_delete(session, id=str(async_test_user.id))
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDRestoreFailures:
|
||||
"""Test restore method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_unexpected_error_triggers_rollback(self, async_test_db):
|
||||
"""Test restore handles unexpected errors with rollback."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# First create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="restore_test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
await user_crud.soft_delete(session, id=str(user_id))
|
||||
|
||||
# Now test restore failure
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Restore failed")
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Restore failed"):
|
||||
await user_crud.restore(session, id=str(user_id))
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseCRUDGetFailures:
|
||||
"""Test get method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_database_error_propagates(self, async_test_db):
|
||||
"""Test get propagates database errors."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Get failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get(session, id=str(uuid4()))
|
||||
|
||||
|
||||
class TestBaseCRUDGetMultiFailures:
|
||||
"""Test get_multi method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_database_error_propagates(self, async_test_db):
|
||||
"""Test get_multi propagates database errors."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get_multi(session, skip=0, limit=10)
|
||||
603
backend/tests/repositories/test_oauth.py
Normal file
603
backend/tests/repositories/test_oauth.py
Normal file
@@ -0,0 +1,603 @@
|
||||
# tests/crud/test_oauth.py
|
||||
"""
|
||||
Comprehensive tests for OAuth CRUD operations.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||
from app.repositories.oauth_client import oauth_client_repo as oauth_client
|
||||
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
|
||||
|
||||
|
||||
class TestOAuthAccountCRUD:
|
||||
"""Tests for OAuth account CRUD operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account(self, async_test_db, async_test_user):
|
||||
"""Test creating an OAuth account link."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
provider_user_id="google_123456",
|
||||
provider_email="user@gmail.com",
|
||||
)
|
||||
account = await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
assert account is not None
|
||||
assert account.provider == "google"
|
||||
assert account.provider_user_id == "google_123456"
|
||||
assert account.user_id == async_test_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_same_provider_twice_fails(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test creating same OAuth account for same user twice raises error."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
provider_user_id="google_dup_123",
|
||||
provider_email="user@gmail.com",
|
||||
)
|
||||
await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
# Try to create same account again (same provider + provider_user_id)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account_data2 = OAuthAccountCreate(
|
||||
user_id=async_test_user.id, # Same user
|
||||
provider="google",
|
||||
provider_user_id="google_dup_123", # Same provider_user_id
|
||||
provider_email="user@gmail.com",
|
||||
)
|
||||
|
||||
# SQLite returns different error message than PostgreSQL
|
||||
with pytest.raises(
|
||||
DuplicateEntryError, match="(already linked|UNIQUE constraint failed|Failed to create)"
|
||||
):
|
||||
await oauth_account.create_account(session, obj_in=account_data2)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_provider_id(self, async_test_db, async_test_user):
|
||||
"""Test getting OAuth account by provider and provider user ID."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=async_test_user.id,
|
||||
provider="github",
|
||||
provider_user_id="github_789",
|
||||
provider_email="user@github.com",
|
||||
)
|
||||
await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_account.get_by_provider_id(
|
||||
session,
|
||||
provider="github",
|
||||
provider_user_id="github_789",
|
||||
)
|
||||
assert result is not None
|
||||
assert result.provider == "github"
|
||||
assert result.user is not None # Eager loaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_provider_id_not_found(self, async_test_db):
|
||||
"""Test getting non-existent OAuth account returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_account.get_by_provider_id(
|
||||
session,
|
||||
provider="google",
|
||||
provider_user_id="nonexistent",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_accounts(self, async_test_db, async_test_user):
|
||||
"""Test getting all OAuth accounts for a user."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Create two accounts for the same user
|
||||
for provider in ["google", "github"]:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=async_test_user.id,
|
||||
provider=provider,
|
||||
provider_user_id=f"{provider}_user_123",
|
||||
provider_email=f"user@{provider}.com",
|
||||
)
|
||||
await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
accounts = await oauth_account.get_user_accounts(
|
||||
session, user_id=async_test_user.id
|
||||
)
|
||||
assert len(accounts) == 2
|
||||
providers = {a.provider for a in accounts}
|
||||
assert providers == {"google", "github"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_account_by_provider(self, async_test_db, async_test_user):
|
||||
"""Test getting specific OAuth account for user and provider."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
provider_user_id="google_specific",
|
||||
)
|
||||
await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_account.get_user_account_by_provider(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
)
|
||||
assert result is not None
|
||||
assert result.provider == "google"
|
||||
|
||||
# Test not found
|
||||
result2 = await oauth_account.get_user_account_by_provider(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
provider="github", # Not linked
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_account(self, async_test_db, async_test_user):
|
||||
"""Test deleting an OAuth account link."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
provider_user_id="google_to_delete",
|
||||
)
|
||||
await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
deleted = await oauth_account.delete_account(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
)
|
||||
assert deleted is True
|
||||
|
||||
# Verify deletion
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_account.get_user_account_by_provider(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_account_not_found(self, async_test_db, async_test_user):
|
||||
"""Test deleting non-existent account returns False."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
deleted = await oauth_account.delete_account(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
provider="nonexistent",
|
||||
)
|
||||
assert deleted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_provider_email(self, async_test_db, async_test_user):
|
||||
"""Test getting OAuth account by provider and email."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
provider_user_id="google_email_test",
|
||||
provider_email="unique@gmail.com",
|
||||
)
|
||||
await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_account.get_by_provider_email(
|
||||
session,
|
||||
provider="google",
|
||||
email="unique@gmail.com",
|
||||
)
|
||||
assert result is not None
|
||||
assert result.provider_email == "unique@gmail.com"
|
||||
|
||||
# Test not found
|
||||
result2 = await oauth_account.get_by_provider_email(
|
||||
session,
|
||||
provider="google",
|
||||
email="nonexistent@gmail.com",
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tokens(self, async_test_db, async_test_user):
|
||||
"""Test updating OAuth tokens."""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
account_data = OAuthAccountCreate(
|
||||
user_id=async_test_user.id,
|
||||
provider="google",
|
||||
provider_user_id="google_token_test",
|
||||
)
|
||||
account = await oauth_account.create_account(session, obj_in=account_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the account first
|
||||
account = await oauth_account.get_by_provider_id(
|
||||
session, provider="google", provider_user_id="google_token_test"
|
||||
)
|
||||
assert account is not None
|
||||
|
||||
# Update tokens
|
||||
new_expires = datetime.now(UTC) + timedelta(hours=1)
|
||||
updated = await oauth_account.update_tokens(
|
||||
session,
|
||||
account=account,
|
||||
access_token="new_access_token",
|
||||
refresh_token="new_refresh_token",
|
||||
token_expires_at=new_expires,
|
||||
)
|
||||
|
||||
assert updated.access_token == "new_access_token"
|
||||
assert updated.refresh_token == "new_refresh_token"
|
||||
|
||||
|
||||
class TestOAuthStateCRUD:
|
||||
"""Tests for OAuth state CRUD operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_state(self, async_test_db):
|
||||
"""Test creating OAuth state."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
state_data = OAuthStateCreate(
|
||||
state="random_state_123",
|
||||
code_verifier="pkce_verifier",
|
||||
nonce="oidc_nonce",
|
||||
provider="google",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
state = await oauth_state.create_state(session, obj_in=state_data)
|
||||
|
||||
assert state is not None
|
||||
assert state.state == "random_state_123"
|
||||
assert state.code_verifier == "pkce_verifier"
|
||||
assert state.provider == "google"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_and_consume_state(self, async_test_db):
|
||||
"""Test getting and consuming OAuth state."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
state_data = OAuthStateCreate(
|
||||
state="consume_state_123",
|
||||
provider="github",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
|
||||
# Consume the state
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_state.get_and_consume_state(
|
||||
session, state="consume_state_123"
|
||||
)
|
||||
assert result is not None
|
||||
assert result.provider == "github"
|
||||
|
||||
# Try to consume again - should be None (already consumed)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result2 = await oauth_state.get_and_consume_state(
|
||||
session, state="consume_state_123"
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_and_consume_expired_state(self, async_test_db):
|
||||
"""Test consuming expired state returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Create expired state
|
||||
state_data = OAuthStateCreate(
|
||||
state="expired_state_123",
|
||||
provider="google",
|
||||
expires_at=datetime.now(UTC) - timedelta(minutes=1), # Already expired
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=state_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_state.get_and_consume_state(
|
||||
session, state="expired_state_123"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_states(self, async_test_db):
|
||||
"""Test cleaning up expired OAuth states."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Create expired state
|
||||
expired_state = OAuthStateCreate(
|
||||
state="cleanup_expired",
|
||||
provider="google",
|
||||
expires_at=datetime.now(UTC) - timedelta(minutes=5),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=expired_state)
|
||||
|
||||
# Create valid state
|
||||
valid_state = OAuthStateCreate(
|
||||
state="cleanup_valid",
|
||||
provider="google",
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
)
|
||||
await oauth_state.create_state(session, obj_in=valid_state)
|
||||
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await oauth_state.cleanup_expired(session)
|
||||
assert count == 1
|
||||
|
||||
# Verify only expired was deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_state.get_and_consume_state(
|
||||
session, state="cleanup_valid"
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestOAuthClientCRUD:
|
||||
"""Tests for OAuth client CRUD operations (provider mode)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_public_client(self, async_test_db):
|
||||
"""Test creating a public OAuth client."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
client_data = OAuthClientCreate(
|
||||
client_name="Test MCP App",
|
||||
client_description="A test application",
|
||||
redirect_uris=["http://localhost:3000/callback"],
|
||||
allowed_scopes=["read:users"],
|
||||
client_type="public",
|
||||
)
|
||||
client, secret = await oauth_client.create_client(
|
||||
session, obj_in=client_data
|
||||
)
|
||||
|
||||
assert client is not None
|
||||
assert client.client_name == "Test MCP App"
|
||||
assert client.client_type == "public"
|
||||
assert secret is None # Public clients don't have secrets
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_confidential_client(self, async_test_db):
|
||||
"""Test creating a confidential OAuth client."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
client_data = OAuthClientCreate(
|
||||
client_name="Confidential App",
|
||||
redirect_uris=["http://localhost:8080/callback"],
|
||||
allowed_scopes=["read:users", "write:users"],
|
||||
client_type="confidential",
|
||||
)
|
||||
client, secret = await oauth_client.create_client(
|
||||
session, obj_in=client_data
|
||||
)
|
||||
|
||||
assert client is not None
|
||||
assert client.client_type == "confidential"
|
||||
assert secret is not None # Confidential clients have secrets
|
||||
assert len(secret) > 20 # Should be a reasonably long secret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_client_id(self, async_test_db):
|
||||
"""Test getting OAuth client by client_id."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
created_client_id = None
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
client_data = OAuthClientCreate(
|
||||
client_name="Lookup Test",
|
||||
redirect_uris=["http://localhost:3000/callback"],
|
||||
allowed_scopes=["read:users"],
|
||||
)
|
||||
client, _ = await oauth_client.create_client(session, obj_in=client_data)
|
||||
created_client_id = client.client_id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_client.get_by_client_id(
|
||||
session, client_id=created_client_id
|
||||
)
|
||||
assert result is not None
|
||||
assert result.client_name == "Lookup Test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_inactive_client_not_found(self, async_test_db):
|
||||
"""Test getting inactive OAuth client returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
created_client_id = None
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
client_data = OAuthClientCreate(
|
||||
client_name="Inactive Client",
|
||||
redirect_uris=["http://localhost:3000/callback"],
|
||||
allowed_scopes=["read:users"],
|
||||
)
|
||||
client, _ = await oauth_client.create_client(session, obj_in=client_data)
|
||||
created_client_id = client.client_id
|
||||
|
||||
# Deactivate
|
||||
await oauth_client.deactivate_client(session, client_id=created_client_id)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_client.get_by_client_id(
|
||||
session, client_id=created_client_id
|
||||
)
|
||||
assert result is None # Inactive clients not returned
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_redirect_uri(self, async_test_db):
|
||||
"""Test redirect URI validation."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
created_client_id = None
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
client_data = OAuthClientCreate(
|
||||
client_name="URI Test",
|
||||
redirect_uris=[
|
||||
"http://localhost:3000/callback",
|
||||
"http://localhost:8080/oauth",
|
||||
],
|
||||
allowed_scopes=["read:users"],
|
||||
)
|
||||
client, _ = await oauth_client.create_client(session, obj_in=client_data)
|
||||
created_client_id = client.client_id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Valid URI
|
||||
valid = await oauth_client.validate_redirect_uri(
|
||||
session,
|
||||
client_id=created_client_id,
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
)
|
||||
assert valid is True
|
||||
|
||||
# Invalid URI
|
||||
invalid = await oauth_client.validate_redirect_uri(
|
||||
session,
|
||||
client_id=created_client_id,
|
||||
redirect_uri="http://evil.com/callback",
|
||||
)
|
||||
assert invalid is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_client_secret(self, async_test_db):
|
||||
"""Test client secret verification."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
created_client_id = None
|
||||
created_secret = None
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
client_data = OAuthClientCreate(
|
||||
client_name="Secret Test",
|
||||
redirect_uris=["http://localhost:3000/callback"],
|
||||
allowed_scopes=["read:users"],
|
||||
client_type="confidential",
|
||||
)
|
||||
client, secret = await oauth_client.create_client(
|
||||
session, obj_in=client_data
|
||||
)
|
||||
created_client_id = client.client_id
|
||||
created_secret = secret
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Valid secret
|
||||
valid = await oauth_client.verify_client_secret(
|
||||
session,
|
||||
client_id=created_client_id,
|
||||
client_secret=created_secret,
|
||||
)
|
||||
assert valid is True
|
||||
|
||||
# Invalid secret
|
||||
invalid = await oauth_client.verify_client_secret(
|
||||
session,
|
||||
client_id=created_client_id,
|
||||
client_secret="wrong_secret",
|
||||
)
|
||||
assert invalid is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_nonexistent_client(self, async_test_db):
|
||||
"""Test deactivating non-existent client returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await oauth_client.deactivate_client(
|
||||
session, client_id="nonexistent_client_id"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_redirect_uri_nonexistent_client(self, async_test_db):
|
||||
"""Test validate_redirect_uri returns False for non-existent client."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
valid = await oauth_client.validate_redirect_uri(
|
||||
session,
|
||||
client_id="nonexistent_client_id",
|
||||
redirect_uri="http://localhost:3000/callback",
|
||||
)
|
||||
assert valid is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_secret_nonexistent_client(self, async_test_db):
|
||||
"""Test verify_client_secret returns False for non-existent client."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
valid = await oauth_client.verify_client_secret(
|
||||
session,
|
||||
client_id="nonexistent_client_id",
|
||||
client_secret="any_secret",
|
||||
)
|
||||
assert valid is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_secret_public_client(self, async_test_db):
|
||||
"""Test verify_client_secret returns False for public client (no secret)."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
client_data = OAuthClientCreate(
|
||||
client_name="Public Client",
|
||||
redirect_uris=["http://localhost:3000/callback"],
|
||||
allowed_scopes=["read:users"],
|
||||
client_type="public", # Public client - no secret
|
||||
)
|
||||
client, secret = await oauth_client.create_client(
|
||||
session, obj_in=client_data
|
||||
)
|
||||
assert secret is None
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Public clients don't have secrets, so verification should fail
|
||||
valid = await oauth_client.verify_client_secret(
|
||||
session,
|
||||
client_id=client.client_id,
|
||||
client_secret="any_secret",
|
||||
)
|
||||
assert valid is False
|
||||
1171
backend/tests/repositories/test_organization.py
Normal file
1171
backend/tests/repositories/test_organization.py
Normal file
File diff suppressed because it is too large
Load Diff
571
backend/tests/repositories/test_session.py
Normal file
571
backend/tests/repositories/test_session.py
Normal file
@@ -0,0 +1,571 @@
|
||||
# tests/crud/test_session_async.py
|
||||
"""
|
||||
Comprehensive tests for async session CRUD operations.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.repository_exceptions import InvalidInputError
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
|
||||
class TestGetByJti:
|
||||
"""Tests for get_by_jti method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_success(self, async_test_db, async_test_user):
|
||||
"""Test getting session by JTI."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="test_jti_123",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.get_by_jti(session, jti="test_jti_123")
|
||||
assert result is not None
|
||||
assert result.refresh_token_jti == "test_jti_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_not_found(self, async_test_db):
|
||||
"""Test getting non-existent JTI returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.get_by_jti(session, jti="nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetActiveByJti:
|
||||
"""Tests for get_active_by_jti method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_success(self, async_test_db, async_test_user):
|
||||
"""Test getting active session by JTI."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active_jti",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.get_active_by_jti(session, jti="active_jti")
|
||||
assert result is not None
|
||||
assert result.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user):
|
||||
"""Test getting inactive session by JTI returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="inactive_jti",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.get_active_by_jti(session, jti="inactive_jti")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetUserSessions:
|
||||
"""Tests for get_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
|
||||
"""Test getting only active user sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active",
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
inactive = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="inactive",
|
||||
device_name="Inactive Device",
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([active, inactive])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id), active_only=True
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0].is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
|
||||
"""Test getting all user sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
sess = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"session_{i}",
|
||||
device_name=f"Device {i}",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=i % 2 == 0,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id), active_only=False
|
||||
)
|
||||
assert len(results) == 3
|
||||
|
||||
|
||||
class TestCreateSession:
|
||||
"""Tests for create_session method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully creating a session_crud."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="new_jti",
|
||||
device_name="New Device",
|
||||
device_id="device_123",
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0",
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
location_city="San Francisco",
|
||||
location_country="USA",
|
||||
)
|
||||
result = await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
assert result.user_id == async_test_user.id
|
||||
assert result.refresh_token_jti == "new_jti"
|
||||
assert result.is_active is True
|
||||
assert result.location_city == "San Francisco"
|
||||
|
||||
|
||||
class TestDeactivate:
|
||||
"""Tests for deactivate method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully deactivating a session_crud."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="to_deactivate",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
session_id = user_session.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.deactivate(session, session_id=str(session_id))
|
||||
assert result is not None
|
||||
assert result.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_not_found(self, async_test_db):
|
||||
"""Test deactivating non-existent session returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.deactivate(session, session_id=str(uuid4()))
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestDeactivateAllUserSessions:
|
||||
"""Tests for deactivate_all_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_user_sessions_success(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test deactivating all user sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Create minimal sessions for test (2 instead of 5)
|
||||
for i in range(2):
|
||||
sess = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"bulk_{i}",
|
||||
device_name=f"Device {i}",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.deactivate_all_user_sessions(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 2
|
||||
|
||||
|
||||
class TestUpdateLastUsed:
|
||||
"""Tests for update_last_used method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_last_used_success(self, async_test_db, async_test_user):
|
||||
"""Test updating last_used_at timestamp."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="update_test",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
old_time = user_session.last_used_at
|
||||
result = await session_crud.update_last_used(session, session=user_session)
|
||||
|
||||
assert result.last_used_at > old_time
|
||||
|
||||
|
||||
class TestGetUserSessionCount:
|
||||
"""Tests for get_user_session_count method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user session count."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
sess = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"count_{i}",
|
||||
device_name=f"Device {i}",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.get_user_session_count(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_empty(self, async_test_db):
|
||||
"""Test getting session count for user with no sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.get_user_session_count(
|
||||
session, user_id=str(uuid4())
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestUpdateRefreshToken:
|
||||
"""Tests for update_refresh_token method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token_success(self, async_test_db, async_test_user):
|
||||
"""Test updating refresh token JTI and expiration."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="old_jti",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
new_jti = "new_jti_123"
|
||||
new_expires = datetime.now(UTC) + timedelta(days=14)
|
||||
|
||||
result = await session_crud.update_refresh_token(
|
||||
session,
|
||||
session=user_session,
|
||||
new_jti=new_jti,
|
||||
new_expires_at=new_expires,
|
||||
)
|
||||
|
||||
assert result.refresh_token_jti == new_jti
|
||||
# Compare timestamps ignoring timezone info
|
||||
assert (
|
||||
abs(
|
||||
(
|
||||
result.expires_at.replace(tzinfo=None)
|
||||
- new_expires.replace(tzinfo=None)
|
||||
).total_seconds()
|
||||
)
|
||||
< 1
|
||||
)
|
||||
|
||||
|
||||
class TestCleanupExpired:
|
||||
"""Tests for cleanup_expired method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_success(self, async_test_db, async_test_user):
|
||||
"""Test cleaning up old expired inactive sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create old expired inactive session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
old_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="old_expired",
|
||||
device_name="Old Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(UTC) - timedelta(days=5),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=35),
|
||||
created_at=datetime.now(UTC) - timedelta(days=35),
|
||||
)
|
||||
session.add(old_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired(session, keep_days=30)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup keeps recent expired sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create recent expired inactive session (less than keep_days old)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
recent_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="recent_expired",
|
||||
device_name="Recent Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
session.add(recent_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired(session, keep_days=30)
|
||||
assert count == 0 # Should not delete recent sessions
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup does not delete active sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create old expired but ACTIVE session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active_expired",
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True, # Active
|
||||
expires_at=datetime.now(UTC) - timedelta(days=5),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=35),
|
||||
created_at=datetime.now(UTC) - timedelta(days=35),
|
||||
)
|
||||
session.add(active_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired(session, keep_days=30)
|
||||
assert count == 0 # Should not delete active sessions
|
||||
|
||||
|
||||
class TestCleanupExpiredForUser:
|
||||
"""Tests for cleanup_expired_for_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_success(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test cleaning up expired sessions for specific user."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create expired inactive session for user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
expired_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="user_expired",
|
||||
device_name="Expired Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
session.add(expired_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup for user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired_for_user(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db):
|
||||
"""Test cleanup with invalid user UUID."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(InvalidInputError, match="Invalid user ID format"):
|
||||
await session_crud.cleanup_expired_for_user(
|
||||
session, user_id="not-a-valid-uuid"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_keeps_active(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test that cleanup for user keeps active sessions."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create expired but active session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active_user_expired",
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True, # Active
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
session.add(active_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired_for_user(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 0 # Should not delete active sessions
|
||||
|
||||
|
||||
class TestGetUserSessionsWithUser:
|
||||
"""Tests for get_user_sessions with eager loading."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_with_user_relationship(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test getting sessions with user relationship loaded."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="with_user",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
|
||||
# Get with user relationship
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id), with_user=True
|
||||
)
|
||||
assert len(results) >= 1
|
||||
387
backend/tests/repositories/test_session_db_failures.py
Normal file
387
backend/tests/repositories/test_session_db_failures.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# tests/crud/test_session_db_failures.py
|
||||
"""
|
||||
Comprehensive tests for session CRUD database failure scenarios.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from app.core.repository_exceptions import IntegrityConstraintError
|
||||
from app.repositories.session import session_repo as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
|
||||
class TestSessionCRUDGetByJtiFailures:
|
||||
"""Test get_by_jti exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_database_error(self, async_test_db):
|
||||
"""Test get_by_jti handles database errors."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("DB connection lost", {}, Exception())
|
||||
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_by_jti(session, jti="test_jti")
|
||||
|
||||
|
||||
class TestSessionCRUDGetActiveByJtiFailures:
|
||||
"""Test get_active_by_jti exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_database_error(self, async_test_db):
|
||||
"""Test get_active_by_jti handles database errors."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query timeout", {}, Exception())
|
||||
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_active_by_jti(session, jti="test_jti")
|
||||
|
||||
|
||||
class TestSessionCRUDGetUserSessionsFailures:
|
||||
"""Test get_user_sessions exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_user_sessions handles database errors."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Database error", {}, Exception())
|
||||
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_user_sessions(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
|
||||
class TestSessionCRUDCreateSessionFailures:
|
||||
"""Test create_session exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test create_session handles commit failures with rollback."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Commit failed", {}, Exception())
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
with pytest.raises(IntegrityConstraintError, match="Failed to create session"):
|
||||
await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_unexpected_error_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test create_session handles unexpected errors."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Unexpected error")
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
with pytest.raises(IntegrityConstraintError, match="Failed to create session"):
|
||||
await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDDeactivateFailures:
|
||||
"""Test deactivate exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test deactivate handles commit failures."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session first
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
session_id = user_session.id
|
||||
|
||||
# Test deactivate failure
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Deactivate failed", {}, Exception())
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.deactivate(
|
||||
session, session_id=str(session_id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDDeactivateAllFailures:
|
||||
"""Test deactivate_all_user_sessions exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test deactivate_all handles commit failures."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Bulk deactivate failed", {}, Exception())
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.deactivate_all_user_sessions(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDUpdateLastUsedFailures:
|
||||
"""Test update_last_used exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_last_used_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test update_last_used handles commit failures."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
# Test update failure
|
||||
async with SessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user_session import UserSession as US
|
||||
|
||||
result = await session.execute(select(US).where(US.id == user_session.id))
|
||||
sess = result.scalar_one()
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Update failed", {}, Exception())
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.update_last_used(session, session=sess)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDUpdateRefreshTokenFailures:
|
||||
"""Test update_refresh_token exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test update_refresh_token handles commit failures."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
# Test update failure
|
||||
async with SessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user_session import UserSession as US
|
||||
|
||||
result = await session.execute(select(US).where(US.id == user_session.id))
|
||||
sess = result.scalar_one()
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Token update failed", {}, Exception())
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.update_refresh_token(
|
||||
session,
|
||||
session=sess,
|
||||
new_jti=str(uuid4()),
|
||||
new_expires_at=datetime.now(UTC) + timedelta(days=14),
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDCleanupExpiredFailures:
|
||||
"""Test cleanup_expired exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_commit_failure_triggers_rollback(
|
||||
self, async_test_db
|
||||
):
|
||||
"""Test cleanup_expired handles commit failures."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Cleanup failed", {}, Exception())
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.cleanup_expired(session, keep_days=30)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDCleanupExpiredForUserFailures:
|
||||
"""Test cleanup_expired_for_user exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test cleanup_expired_for_user handles commit failures."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("User cleanup failed", {}, Exception())
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.cleanup_expired_for_user(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestSessionCRUDGetUserSessionCountFailures:
|
||||
"""Test get_user_session_count exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_user_session_count handles database errors."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Count query failed", {}, Exception())
|
||||
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_user_session_count(
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
665
backend/tests/repositories/test_user.py
Normal file
665
backend/tests/repositories/test_user.py
Normal file
@@ -0,0 +1,665 @@
|
||||
# tests/crud/test_user_async.py
|
||||
"""
|
||||
Comprehensive tests for async user CRUD operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
|
||||
from app.repositories.user import user_repo as user_crud
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class TestGetByEmail:
|
||||
"""Tests for get_by_email method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user by email."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_crud.get_by_email(session, email=async_test_user.email)
|
||||
assert result is not None
|
||||
assert result.email == async_test_user.email
|
||||
assert result.id == async_test_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_not_found(self, async_test_db):
|
||||
"""Test getting non-existent email returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_crud.get_by_email(
|
||||
session, email="nonexistent@example.com"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreate:
|
||||
"""Tests for create method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, async_test_db):
|
||||
"""Test successfully creating a user_crud."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="newuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="+1234567890",
|
||||
)
|
||||
result = await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
assert result.email == "newuser@example.com"
|
||||
assert result.first_name == "New"
|
||||
assert result.last_name == "User"
|
||||
assert result.phone_number == "+1234567890"
|
||||
assert result.is_active is True
|
||||
assert result.is_superuser is False
|
||||
assert result.password_hash is not None
|
||||
assert result.password_hash != "SecurePass123!" # Password should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_superuser_success(self, async_test_db):
|
||||
"""Test creating a superuser."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="superuser@example.com",
|
||||
password="SuperPass123!",
|
||||
first_name="Super",
|
||||
last_name="User",
|
||||
is_superuser=True,
|
||||
)
|
||||
result = await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
assert result.is_superuser is True
|
||||
assert result.email == "superuser@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
|
||||
"""Test creating user with duplicate email raises ValueError."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Duplicate email
|
||||
password="AnotherPass123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(DuplicateEntryError) as exc_info:
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
assert "already exists" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
"""Tests for update method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
|
||||
"""Test updating basic user fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get fresh copy of user
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_data = UserUpdate(
|
||||
first_name="Updated", last_name="Name", phone_number="+9876543210"
|
||||
)
|
||||
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
assert result.first_name == "Updated"
|
||||
assert result.last_name == "Name"
|
||||
assert result.phone_number == "+9876543210"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_password(self, async_test_db):
|
||||
"""Test updating user password."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a fresh user for this test
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="passwordtest@example.com",
|
||||
password="OldPassword123!",
|
||||
first_name="Pass",
|
||||
last_name="Test",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
old_password_hash = user.password_hash
|
||||
|
||||
# Update the password
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(user_id))
|
||||
|
||||
update_data = UserUpdate(password="NewDifferentPassword123!")
|
||||
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
await session.refresh(result)
|
||||
assert result.password_hash != old_password_hash
|
||||
assert result.password_hash is not None
|
||||
assert (
|
||||
"NewDifferentPassword123!" not in result.password_hash
|
||||
) # Should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test updating user with dictionary."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_dict = {"first_name": "DictUpdate"}
|
||||
result = await user_crud.update(session, db_obj=user, obj_in=update_dict)
|
||||
|
||||
assert result.first_name == "DictUpdate"
|
||||
|
||||
|
||||
class TestGetMultiWithTotal:
|
||||
"""Tests for get_multi_with_total method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
|
||||
"""Test basic pagination."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10
|
||||
)
|
||||
assert total >= 1
|
||||
assert len(users) >= 1
|
||||
assert any(u.id == async_test_user.id for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
|
||||
"""Test sorting in ascending order."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"sort{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="email", sort_order="asc"
|
||||
)
|
||||
|
||||
# Check if sorted (at least the test users)
|
||||
test_users = [u for u in users if u.email.startswith("sort")]
|
||||
if len(test_users) > 1:
|
||||
assert test_users[0].email < test_users[1].email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
|
||||
"""Test sorting in descending order."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"desc{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="email", sort_order="desc"
|
||||
)
|
||||
|
||||
# Check if sorted descending (at least the test users)
|
||||
test_users = [u for u in users if u.email.startswith("desc")]
|
||||
if len(test_users) > 1:
|
||||
assert test_users[0].email > test_users[1].email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_filtering(self, async_test_db):
|
||||
"""Test filtering by field."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active and inactive users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_user = UserCreate(
|
||||
email="active@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Active",
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=active_user)
|
||||
|
||||
inactive_user = UserCreate(
|
||||
email="inactive@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User",
|
||||
)
|
||||
created_inactive = await user_crud.create(session, obj_in=inactive_user)
|
||||
|
||||
# Deactivate the user
|
||||
await user_crud.update(
|
||||
session, db_obj=created_inactive, obj_in={"is_active": False}
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=100, filters={"is_active": True}
|
||||
)
|
||||
|
||||
# All returned users should be active
|
||||
assert all(u.is_active for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_search(self, async_test_db):
|
||||
"""Test search functionality."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user with unique name
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="searchable@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Searchable",
|
||||
last_name="UserName",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=100, search="Searchable"
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert any(u.first_name == "Searchable" for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_pagination(self, async_test_db):
|
||||
"""Test pagination with skip and limit."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(5):
|
||||
user_data = UserCreate(
|
||||
email=f"page{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Page{i}",
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get first page
|
||||
users_page1, total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=2
|
||||
)
|
||||
|
||||
# Get second page
|
||||
users_page2, total2 = await user_crud.get_multi_with_total(
|
||||
session, skip=2, limit=2
|
||||
)
|
||||
|
||||
# Total should be same
|
||||
assert total == total2
|
||||
# Different users on different pages
|
||||
assert users_page1[0].id != users_page2[0].id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
|
||||
"""Test validation fails for negative skip."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(InvalidInputError) as exc_info:
|
||||
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
||||
|
||||
assert "skip must be non-negative" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
|
||||
"""Test validation fails for negative limit."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(InvalidInputError) as exc_info:
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
||||
|
||||
assert "limit must be non-negative" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
|
||||
"""Test validation fails for limit > 1000."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(InvalidInputError) as exc_info:
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
assert "Maximum limit is 1000" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestBulkUpdateStatus:
|
||||
"""Tests for bulk_update_status method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_success(self, async_test_db):
|
||||
"""Test bulk updating user status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"bulk{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Bulk{i}",
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk deactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session, user_ids=user_ids, is_active=False
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are inactive
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for user_id in user_ids:
|
||||
user = await user_crud.get(session, id=str(user_id))
|
||||
assert user.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_empty_list(self, async_test_db):
|
||||
"""Test bulk update with empty list returns 0."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session, user_ids=[], is_active=False
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_reactivate(self, async_test_db):
|
||||
"""Test bulk reactivating users."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="reactivate@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Reactivate",
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
# Deactivate
|
||||
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
user_id = user.id
|
||||
|
||||
# Reactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session, user_ids=[user_id], is_active=True
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
# Verify active
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(user_id))
|
||||
assert user.is_active is True
|
||||
|
||||
|
||||
class TestBulkSoftDelete:
|
||||
"""Tests for bulk_soft_delete method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_success(self, async_test_db):
|
||||
"""Test bulk soft deleting users."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"delete{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Delete{i}",
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are soft deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for user_id in user_ids:
|
||||
user = await user_crud.get(session, id=str(user_id))
|
||||
assert user.deleted_at is not None
|
||||
assert user.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
|
||||
"""Test bulk soft delete with excluded user_crud."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"exclude{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Exclude{i}",
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete, excluding first user
|
||||
exclude_id = user_ids[0]
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session, user_ids=user_ids, exclude_user_id=exclude_id
|
||||
)
|
||||
assert count == 2 # Only 2 deleted
|
||||
|
||||
# Verify excluded user is NOT deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
excluded_user = await user_crud.get(session, id=str(exclude_id))
|
||||
assert excluded_user.deleted_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_empty_list(self, async_test_db):
|
||||
"""Test bulk delete with empty list returns 0."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(session, user_ids=[])
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
|
||||
"""Test bulk delete where all users are excluded."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="onlyuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Only",
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
|
||||
# Try to delete but exclude
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session, user_ids=[user_id], exclude_user_id=user_id
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
|
||||
"""Test bulk delete doesn't re-delete already deleted users."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create and delete user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="predeleted@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="PreDeleted",
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
|
||||
# First deletion
|
||||
await user_crud.bulk_soft_delete(session, user_ids=[user_id])
|
||||
|
||||
# Try to delete again
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
|
||||
assert count == 0 # Already deleted
|
||||
|
||||
|
||||
class TestUtilityMethods:
|
||||
"""Tests for utility methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_true(self, async_test_db, async_test_user):
|
||||
"""Test is_active returns True for active user_crud."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
assert user_crud.is_active(user) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_false(self, async_test_db):
|
||||
"""Test is_active returns False for inactive user_crud."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="inactive2@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
|
||||
assert user_crud.is_active(user) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
|
||||
"""Test is_superuser returns True for superuser."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_superuser.id))
|
||||
assert user_crud.is_superuser(user) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_false(self, async_test_db, async_test_user):
|
||||
"""Test is_superuser returns False for regular user_crud."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
assert user_crud.is_superuser(user) is False
|
||||
|
||||
|
||||
class TestUserExceptionHandlers:
|
||||
"""
|
||||
Test exception handlers in user CRUD methods.
|
||||
Covers lines: 30-32, 205-208, 257-260
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_database_error(self, async_test_db):
|
||||
"""Test get_by_email handles database errors (covers lines 30-32)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with patch.object(
|
||||
session, "execute", side_effect=Exception("Database query failed")
|
||||
):
|
||||
with pytest.raises(Exception, match="Database query failed"):
|
||||
await user_crud.get_by_email(session, email="test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test bulk_update_status handles database errors (covers lines 205-208)."""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock execute to fail
|
||||
with patch.object(
|
||||
session, "execute", side_effect=Exception("Bulk update failed")
|
||||
):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
with pytest.raises(Exception, match="Bulk update failed"):
|
||||
await user_crud.bulk_update_status(
|
||||
session, user_ids=[async_test_user.id], is_active=False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test bulk_soft_delete handles database errors (covers lines 257-260)."""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock execute to fail
|
||||
with patch.object(
|
||||
session, "execute", side_effect=Exception("Bulk delete failed")
|
||||
):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
with pytest.raises(Exception, match="Bulk delete failed"):
|
||||
await user_crud.bulk_soft_delete(
|
||||
session, user_ids=[async_test_user.id]
|
||||
)
|
||||
Reference in New Issue
Block a user