refactor(backend): enforce route→service→repo layered architecture

- introduce custom repository exception hierarchy (DuplicateEntryError,
  IntegrityConstraintError, InvalidInputError) replacing raw ValueError
- eliminate all direct repository imports and raw SQL from route layer
- add UserService, SessionService, OrganizationService to service layer
- add get_stats/get_org_distribution service methods replacing admin inline SQL
- fix timing side-channel in authenticate_user via dummy bcrypt check
- replace SHA-256 client secret fallback with explicit InvalidClientError
- replace assert with InvalidGrantError in authorization code exchange
- replace N+1 token revocation loops with bulk UPDATE statements
- rename oauth account token fields (drop misleading 'encrypted' suffix)
- add Alembic migration 0003 for token field column rename
- add 45 new service/repository tests; 975 passing, 94% coverage
This commit is contained in:
2026-02-27 09:32:57 +01:00
parent 0646c96b19
commit 98b455fdc3
62 changed files with 2933 additions and 1728 deletions

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,334 @@
# tests/crud/test_base_db_failures.py
"""
Comprehensive tests for base CRUD database failure scenarios.
Tests exception handling, rollbacks, and error messages.
"""
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.exc import DataError, OperationalError
from app.core.repository_exceptions import IntegrityConstraintError
from app.repositories.user import user_repo as user_crud
from app.schemas.users import UserCreate
class TestBaseCRUDCreateFailures:
"""Test base CRUD create method exception handling."""
@pytest.mark.asyncio
async def test_create_operational_error_triggers_rollback(self, async_test_db):
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError(
"Connection lost", {}, Exception("DB connection failed")
)
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="operror@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
with pytest.raises(OperationalError):
await user_crud.create(session, obj_in=user_data)
# Verify rollback was called
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_create_data_error_triggers_rollback(self, async_test_db):
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise DataError("Invalid data type", {}, Exception("Data overflow"))
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="dataerror@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
with pytest.raises(DataError):
await user_crud.create(session, obj_in=user_data)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_create_unexpected_exception_triggers_rollback(self, async_test_db):
"""Test that unexpected exceptions trigger rollback and re-raise."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Unexpected database error")
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="unexpected@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User",
)
with pytest.raises(RuntimeError, match="Unexpected database error"):
await user_crud.create(session, obj_in=user_data)
mock_rollback.assert_called_once()
class TestBaseCRUDUpdateFailures:
"""Test base CRUD update method exception handling."""
@pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user):
"""Test update with OperationalError."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
async def mock_commit():
raise OperationalError("Connection timeout", {}, Exception("Timeout"))
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_data_error(self, async_test_db, async_test_user):
"""Test update with DataError."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
async def mock_commit():
raise DataError("Invalid data", {}, Exception("Data type mismatch"))
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(IntegrityConstraintError, match="Database operation failed"):
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_unexpected_error(self, async_test_db, async_test_user):
"""Test update with unexpected error."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
async def mock_commit():
raise KeyError("Unexpected error")
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(KeyError):
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
class TestBaseCRUDRemoveFailures:
"""Test base CRUD remove method exception handling."""
@pytest.mark.asyncio
async def test_remove_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test that unexpected errors in remove trigger rollback."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Database write failed")
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Database write failed"):
await user_crud.remove(session, id=str(async_test_user.id))
mock_rollback.assert_called_once()
class TestBaseCRUDGetMultiWithTotalFailures:
"""Test get_multi_with_total exception handling."""
@pytest.mark.asyncio
async def test_get_multi_with_total_database_error(self, async_test_db):
"""Test get_multi_with_total handles database errors."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock execute to raise an error
async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("Database error"))
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi_with_total(session, skip=0, limit=10)
class TestBaseCRUDCountFailures:
"""Test count method exception handling."""
@pytest.mark.asyncio
async def test_count_database_error_propagates(self, async_test_db):
"""Test count propagates database errors."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Count failed", {}, Exception("DB error"))
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.count(session)
class TestBaseCRUDSoftDeleteFailures:
"""Test soft_delete method exception handling."""
@pytest.mark.asyncio
async def test_soft_delete_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test soft_delete handles unexpected errors with rollback."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Soft delete failed")
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Soft delete failed"):
await user_crud.soft_delete(session, id=str(async_test_user.id))
mock_rollback.assert_called_once()
class TestBaseCRUDRestoreFailures:
"""Test restore method exception handling."""
@pytest.mark.asyncio
async def test_restore_unexpected_error_triggers_rollback(self, async_test_db):
"""Test restore handles unexpected errors with rollback."""
_test_engine, SessionLocal = async_test_db
# First create and soft delete a user
async with SessionLocal() as session:
user_data = UserCreate(
email="restore_test@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
await session.commit()
async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id))
# Now test restore failure
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Restore failed")
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Restore failed"):
await user_crud.restore(session, id=str(user_id))
mock_rollback.assert_called_once()
class TestBaseCRUDGetFailures:
"""Test get method exception handling."""
@pytest.mark.asyncio
async def test_get_database_error_propagates(self, async_test_db):
"""Test get propagates database errors."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Get failed", {}, Exception("DB error"))
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get(session, id=str(uuid4()))
class TestBaseCRUDGetMultiFailures:
"""Test get_multi method exception handling."""
@pytest.mark.asyncio
async def test_get_multi_database_error_propagates(self, async_test_db):
"""Test get_multi propagates database errors."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("DB error"))
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi(session, skip=0, limit=10)

View File

@@ -0,0 +1,603 @@
# tests/crud/test_oauth.py
"""
Comprehensive tests for OAuth CRUD operations.
"""
from datetime import UTC, datetime, timedelta
import pytest
from app.core.repository_exceptions import DuplicateEntryError
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_client import oauth_client_repo as oauth_client
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
class TestOAuthAccountCRUD:
"""Tests for OAuth account CRUD operations."""
@pytest.mark.asyncio
async def test_create_account(self, async_test_db, async_test_user):
"""Test creating an OAuth account link."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_123456",
provider_email="user@gmail.com",
)
account = await oauth_account.create_account(session, obj_in=account_data)
assert account is not None
assert account.provider == "google"
assert account.provider_user_id == "google_123456"
assert account.user_id == async_test_user.id
@pytest.mark.asyncio
async def test_create_account_same_provider_twice_fails(
self, async_test_db, async_test_user
):
"""Test creating same OAuth account for same user twice raises error."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_dup_123",
provider_email="user@gmail.com",
)
await oauth_account.create_account(session, obj_in=account_data)
# Try to create same account again (same provider + provider_user_id)
async with AsyncTestingSessionLocal() as session:
account_data2 = OAuthAccountCreate(
user_id=async_test_user.id, # Same user
provider="google",
provider_user_id="google_dup_123", # Same provider_user_id
provider_email="user@gmail.com",
)
# SQLite returns different error message than PostgreSQL
with pytest.raises(
DuplicateEntryError, match="(already linked|UNIQUE constraint failed|Failed to create)"
):
await oauth_account.create_account(session, obj_in=account_data2)
@pytest.mark.asyncio
async def test_get_by_provider_id(self, async_test_db, async_test_user):
"""Test getting OAuth account by provider and provider user ID."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="github",
provider_user_id="github_789",
provider_email="user@github.com",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_by_provider_id(
session,
provider="github",
provider_user_id="github_789",
)
assert result is not None
assert result.provider == "github"
assert result.user is not None # Eager loaded
@pytest.mark.asyncio
async def test_get_by_provider_id_not_found(self, async_test_db):
"""Test getting non-existent OAuth account returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_by_provider_id(
session,
provider="google",
provider_user_id="nonexistent",
)
assert result is None
@pytest.mark.asyncio
async def test_get_user_accounts(self, async_test_db, async_test_user):
"""Test getting all OAuth accounts for a user."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create two accounts for the same user
for provider in ["google", "github"]:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider=provider,
provider_user_id=f"{provider}_user_123",
provider_email=f"user@{provider}.com",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
accounts = await oauth_account.get_user_accounts(
session, user_id=async_test_user.id
)
assert len(accounts) == 2
providers = {a.provider for a in accounts}
assert providers == {"google", "github"}
@pytest.mark.asyncio
async def test_get_user_account_by_provider(self, async_test_db, async_test_user):
"""Test getting specific OAuth account for user and provider."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_specific",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_user_account_by_provider(
session,
user_id=async_test_user.id,
provider="google",
)
assert result is not None
assert result.provider == "google"
# Test not found
result2 = await oauth_account.get_user_account_by_provider(
session,
user_id=async_test_user.id,
provider="github", # Not linked
)
assert result2 is None
@pytest.mark.asyncio
async def test_delete_account(self, async_test_db, async_test_user):
"""Test deleting an OAuth account link."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_to_delete",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
deleted = await oauth_account.delete_account(
session,
user_id=async_test_user.id,
provider="google",
)
assert deleted is True
# Verify deletion
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_user_account_by_provider(
session,
user_id=async_test_user.id,
provider="google",
)
assert result is None
@pytest.mark.asyncio
async def test_delete_account_not_found(self, async_test_db, async_test_user):
"""Test deleting non-existent account returns False."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
deleted = await oauth_account.delete_account(
session,
user_id=async_test_user.id,
provider="nonexistent",
)
assert deleted is False
@pytest.mark.asyncio
async def test_get_by_provider_email(self, async_test_db, async_test_user):
"""Test getting OAuth account by provider and email."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_email_test",
provider_email="unique@gmail.com",
)
await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
result = await oauth_account.get_by_provider_email(
session,
provider="google",
email="unique@gmail.com",
)
assert result is not None
assert result.provider_email == "unique@gmail.com"
# Test not found
result2 = await oauth_account.get_by_provider_email(
session,
provider="google",
email="nonexistent@gmail.com",
)
assert result2 is None
@pytest.mark.asyncio
async def test_update_tokens(self, async_test_db, async_test_user):
"""Test updating OAuth tokens."""
from datetime import UTC, datetime, timedelta
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
account_data = OAuthAccountCreate(
user_id=async_test_user.id,
provider="google",
provider_user_id="google_token_test",
)
account = await oauth_account.create_account(session, obj_in=account_data)
async with AsyncTestingSessionLocal() as session:
# Get the account first
account = await oauth_account.get_by_provider_id(
session, provider="google", provider_user_id="google_token_test"
)
assert account is not None
# Update tokens
new_expires = datetime.now(UTC) + timedelta(hours=1)
updated = await oauth_account.update_tokens(
session,
account=account,
access_token="new_access_token",
refresh_token="new_refresh_token",
token_expires_at=new_expires,
)
assert updated.access_token == "new_access_token"
assert updated.refresh_token == "new_refresh_token"
class TestOAuthStateCRUD:
"""Tests for OAuth state CRUD operations."""
@pytest.mark.asyncio
async def test_create_state(self, async_test_db):
"""Test creating OAuth state."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
state_data = OAuthStateCreate(
state="random_state_123",
code_verifier="pkce_verifier",
nonce="oidc_nonce",
provider="google",
redirect_uri="http://localhost:3000/callback",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
state = await oauth_state.create_state(session, obj_in=state_data)
assert state is not None
assert state.state == "random_state_123"
assert state.code_verifier == "pkce_verifier"
assert state.provider == "google"
@pytest.mark.asyncio
async def test_get_and_consume_state(self, async_test_db):
"""Test getting and consuming OAuth state."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
state_data = OAuthStateCreate(
state="consume_state_123",
provider="github",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=state_data)
# Consume the state
async with AsyncTestingSessionLocal() as session:
result = await oauth_state.get_and_consume_state(
session, state="consume_state_123"
)
assert result is not None
assert result.provider == "github"
# Try to consume again - should be None (already consumed)
async with AsyncTestingSessionLocal() as session:
result2 = await oauth_state.get_and_consume_state(
session, state="consume_state_123"
)
assert result2 is None
@pytest.mark.asyncio
async def test_get_and_consume_expired_state(self, async_test_db):
"""Test consuming expired state returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create expired state
state_data = OAuthStateCreate(
state="expired_state_123",
provider="google",
expires_at=datetime.now(UTC) - timedelta(minutes=1), # Already expired
)
await oauth_state.create_state(session, obj_in=state_data)
async with AsyncTestingSessionLocal() as session:
result = await oauth_state.get_and_consume_state(
session, state="expired_state_123"
)
assert result is None
@pytest.mark.asyncio
async def test_cleanup_expired_states(self, async_test_db):
"""Test cleaning up expired OAuth states."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create expired state
expired_state = OAuthStateCreate(
state="cleanup_expired",
provider="google",
expires_at=datetime.now(UTC) - timedelta(minutes=5),
)
await oauth_state.create_state(session, obj_in=expired_state)
# Create valid state
valid_state = OAuthStateCreate(
state="cleanup_valid",
provider="google",
expires_at=datetime.now(UTC) + timedelta(minutes=10),
)
await oauth_state.create_state(session, obj_in=valid_state)
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await oauth_state.cleanup_expired(session)
assert count == 1
# Verify only expired was deleted
async with AsyncTestingSessionLocal() as session:
result = await oauth_state.get_and_consume_state(
session, state="cleanup_valid"
)
assert result is not None
class TestOAuthClientCRUD:
"""Tests for OAuth client CRUD operations (provider mode)."""
@pytest.mark.asyncio
async def test_create_public_client(self, async_test_db):
"""Test creating a public OAuth client."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Test MCP App",
client_description="A test application",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
client_type="public",
)
client, secret = await oauth_client.create_client(
session, obj_in=client_data
)
assert client is not None
assert client.client_name == "Test MCP App"
assert client.client_type == "public"
assert secret is None # Public clients don't have secrets
@pytest.mark.asyncio
async def test_create_confidential_client(self, async_test_db):
"""Test creating a confidential OAuth client."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Confidential App",
redirect_uris=["http://localhost:8080/callback"],
allowed_scopes=["read:users", "write:users"],
client_type="confidential",
)
client, secret = await oauth_client.create_client(
session, obj_in=client_data
)
assert client is not None
assert client.client_type == "confidential"
assert secret is not None # Confidential clients have secrets
assert len(secret) > 20 # Should be a reasonably long secret
@pytest.mark.asyncio
async def test_get_by_client_id(self, async_test_db):
"""Test getting OAuth client by client_id."""
_test_engine, AsyncTestingSessionLocal = async_test_db
created_client_id = None
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Lookup Test",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
)
client, _ = await oauth_client.create_client(session, obj_in=client_data)
created_client_id = client.client_id
async with AsyncTestingSessionLocal() as session:
result = await oauth_client.get_by_client_id(
session, client_id=created_client_id
)
assert result is not None
assert result.client_name == "Lookup Test"
@pytest.mark.asyncio
async def test_get_inactive_client_not_found(self, async_test_db):
"""Test getting inactive OAuth client returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
created_client_id = None
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Inactive Client",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
)
client, _ = await oauth_client.create_client(session, obj_in=client_data)
created_client_id = client.client_id
# Deactivate
await oauth_client.deactivate_client(session, client_id=created_client_id)
async with AsyncTestingSessionLocal() as session:
result = await oauth_client.get_by_client_id(
session, client_id=created_client_id
)
assert result is None # Inactive clients not returned
@pytest.mark.asyncio
async def test_validate_redirect_uri(self, async_test_db):
"""Test redirect URI validation."""
_test_engine, AsyncTestingSessionLocal = async_test_db
created_client_id = None
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="URI Test",
redirect_uris=[
"http://localhost:3000/callback",
"http://localhost:8080/oauth",
],
allowed_scopes=["read:users"],
)
client, _ = await oauth_client.create_client(session, obj_in=client_data)
created_client_id = client.client_id
async with AsyncTestingSessionLocal() as session:
# Valid URI
valid = await oauth_client.validate_redirect_uri(
session,
client_id=created_client_id,
redirect_uri="http://localhost:3000/callback",
)
assert valid is True
# Invalid URI
invalid = await oauth_client.validate_redirect_uri(
session,
client_id=created_client_id,
redirect_uri="http://evil.com/callback",
)
assert invalid is False
@pytest.mark.asyncio
async def test_verify_client_secret(self, async_test_db):
"""Test client secret verification."""
_test_engine, AsyncTestingSessionLocal = async_test_db
created_client_id = None
created_secret = None
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Secret Test",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
client_type="confidential",
)
client, secret = await oauth_client.create_client(
session, obj_in=client_data
)
created_client_id = client.client_id
created_secret = secret
async with AsyncTestingSessionLocal() as session:
# Valid secret
valid = await oauth_client.verify_client_secret(
session,
client_id=created_client_id,
client_secret=created_secret,
)
assert valid is True
# Invalid secret
invalid = await oauth_client.verify_client_secret(
session,
client_id=created_client_id,
client_secret="wrong_secret",
)
assert invalid is False
@pytest.mark.asyncio
async def test_deactivate_nonexistent_client(self, async_test_db):
"""Test deactivating non-existent client returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await oauth_client.deactivate_client(
session, client_id="nonexistent_client_id"
)
assert result is None
@pytest.mark.asyncio
async def test_validate_redirect_uri_nonexistent_client(self, async_test_db):
"""Test validate_redirect_uri returns False for non-existent client."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
valid = await oauth_client.validate_redirect_uri(
session,
client_id="nonexistent_client_id",
redirect_uri="http://localhost:3000/callback",
)
assert valid is False
@pytest.mark.asyncio
async def test_verify_secret_nonexistent_client(self, async_test_db):
"""Test verify_client_secret returns False for non-existent client."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
valid = await oauth_client.verify_client_secret(
session,
client_id="nonexistent_client_id",
client_secret="any_secret",
)
assert valid is False
@pytest.mark.asyncio
async def test_verify_secret_public_client(self, async_test_db):
"""Test verify_client_secret returns False for public client (no secret)."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
client_data = OAuthClientCreate(
client_name="Public Client",
redirect_uris=["http://localhost:3000/callback"],
allowed_scopes=["read:users"],
client_type="public", # Public client - no secret
)
client, secret = await oauth_client.create_client(
session, obj_in=client_data
)
assert secret is None
async with AsyncTestingSessionLocal() as session:
# Public clients don't have secrets, so verification should fail
valid = await oauth_client.verify_client_secret(
session,
client_id=client.client_id,
client_secret="any_secret",
)
assert valid is False

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,571 @@
# tests/crud/test_session_async.py
"""
Comprehensive tests for async session CRUD operations.
"""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from app.core.repository_exceptions import InvalidInputError
from app.repositories.session import session_repo as session_crud
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate
class TestGetByJti:
"""Tests for get_by_jti method."""
@pytest.mark.asyncio
async def test_get_by_jti_success(self, async_test_db, async_test_user):
"""Test getting session by JTI."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="test_jti_123",
device_name="Test Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="test_jti_123")
assert result is not None
assert result.refresh_token_jti == "test_jti_123"
@pytest.mark.asyncio
async def test_get_by_jti_not_found(self, async_test_db):
"""Test getting non-existent JTI returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="nonexistent")
assert result is None
class TestGetActiveByJti:
"""Tests for get_active_by_jti method."""
@pytest.mark.asyncio
async def test_get_active_by_jti_success(self, async_test_db, async_test_user):
"""Test getting active session by JTI."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="active_jti",
device_name="Test Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_active_by_jti(session, jti="active_jti")
assert result is not None
assert result.is_active is True
@pytest.mark.asyncio
async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user):
"""Test getting inactive session by JTI returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="inactive_jti",
device_name="Test Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_active_by_jti(session, jti="inactive_jti")
assert result is None
class TestGetUserSessions:
"""Tests for get_user_sessions method."""
@pytest.mark.asyncio
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
"""Test getting only active user sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
active = UserSession(
user_id=async_test_user.id,
refresh_token_jti="active",
device_name="Active Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
inactive = UserSession(
user_id=async_test_user.id,
refresh_token_jti="inactive",
device_name="Inactive Device",
ip_address="192.168.1.2",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add_all([active, inactive])
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=True
)
assert len(results) == 1
assert results[0].is_active is True
@pytest.mark.asyncio
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
"""Test getting all user sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
for i in range(3):
sess = UserSession(
user_id=async_test_user.id,
refresh_token_jti=f"session_{i}",
device_name=f"Device {i}",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=i % 2 == 0,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=False
)
assert len(results) == 3
class TestCreateSession:
"""Tests for create_session method."""
@pytest.mark.asyncio
async def test_create_session_success(self, async_test_db, async_test_user):
"""Test successfully creating a session_crud."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
session_data = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti="new_jti",
device_name="New Device",
device_id="device_123",
ip_address="192.168.1.100",
user_agent="Mozilla/5.0",
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=7),
location_city="San Francisco",
location_country="USA",
)
result = await session_crud.create_session(session, obj_in=session_data)
assert result.user_id == async_test_user.id
assert result.refresh_token_jti == "new_jti"
assert result.is_active is True
assert result.location_city == "San Francisco"
class TestDeactivate:
"""Tests for deactivate method."""
@pytest.mark.asyncio
async def test_deactivate_success(self, async_test_db, async_test_user):
"""Test successfully deactivating a session_crud."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="to_deactivate",
device_name="Test Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
session_id = user_session.id
async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(session_id))
assert result is not None
assert result.is_active is False
@pytest.mark.asyncio
async def test_deactivate_not_found(self, async_test_db):
"""Test deactivating non-existent session returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(uuid4()))
assert result is None
class TestDeactivateAllUserSessions:
"""Tests for deactivate_all_user_sessions method."""
@pytest.mark.asyncio
async def test_deactivate_all_user_sessions_success(
self, async_test_db, async_test_user
):
"""Test deactivating all user sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create minimal sessions for test (2 instead of 5)
for i in range(2):
sess = UserSession(
user_id=async_test_user.id,
refresh_token_jti=f"bulk_{i}",
device_name=f"Device {i}",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id)
)
assert count == 2
class TestUpdateLastUsed:
"""Tests for update_last_used method."""
@pytest.mark.asyncio
async def test_update_last_used_success(self, async_test_db, async_test_user):
"""Test updating last_used_at timestamp."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="update_test",
device_name="Test Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
await session.refresh(user_session)
old_time = user_session.last_used_at
result = await session_crud.update_last_used(session, session=user_session)
assert result.last_used_at > old_time
class TestGetUserSessionCount:
"""Tests for get_user_session_count method."""
@pytest.mark.asyncio
async def test_get_user_session_count_success(self, async_test_db, async_test_user):
"""Test getting user session count."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
for i in range(3):
sess = UserSession(
user_id=async_test_user.id,
refresh_token_jti=f"count_{i}",
device_name=f"Device {i}",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
session, user_id=str(async_test_user.id)
)
assert count == 3
@pytest.mark.asyncio
async def test_get_user_session_count_empty(self, async_test_db):
"""Test getting session count for user with no sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
session, user_id=str(uuid4())
)
assert count == 0
class TestUpdateRefreshToken:
"""Tests for update_refresh_token method."""
@pytest.mark.asyncio
async def test_update_refresh_token_success(self, async_test_db, async_test_user):
"""Test updating refresh token JTI and expiration."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="old_jti",
device_name="Test Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
await session.refresh(user_session)
new_jti = "new_jti_123"
new_expires = datetime.now(UTC) + timedelta(days=14)
result = await session_crud.update_refresh_token(
session,
session=user_session,
new_jti=new_jti,
new_expires_at=new_expires,
)
assert result.refresh_token_jti == new_jti
# Compare timestamps ignoring timezone info
assert (
abs(
(
result.expires_at.replace(tzinfo=None)
- new_expires.replace(tzinfo=None)
).total_seconds()
)
< 1
)
class TestCleanupExpired:
"""Tests for cleanup_expired method."""
@pytest.mark.asyncio
async def test_cleanup_expired_success(self, async_test_db, async_test_user):
"""Test cleaning up old expired inactive sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired inactive session
async with AsyncTestingSessionLocal() as session:
old_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="old_expired",
device_name="Old Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(UTC) - timedelta(days=35),
)
session.add(old_session)
await session.commit()
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30)
assert count == 1
@pytest.mark.asyncio
async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user):
"""Test that cleanup keeps recent expired sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create recent expired inactive session (less than keep_days old)
async with AsyncTestingSessionLocal() as session:
recent_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="recent_expired",
device_name="Recent Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2),
created_at=datetime.now(UTC) - timedelta(days=1),
)
session.add(recent_session)
await session.commit()
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30)
assert count == 0 # Should not delete recent sessions
@pytest.mark.asyncio
async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user):
"""Test that cleanup does not delete active sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired but ACTIVE session
async with AsyncTestingSessionLocal() as session:
active_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="active_expired",
device_name="Active Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True, # Active
expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(UTC) - timedelta(days=35),
)
session.add(active_session)
await session.commit()
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30)
assert count == 0 # Should not delete active sessions
class TestCleanupExpiredForUser:
"""Tests for cleanup_expired_for_user method."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_success(
self, async_test_db, async_test_user
):
"""Test cleaning up expired sessions for specific user."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired inactive session for user
async with AsyncTestingSessionLocal() as session:
expired_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="user_expired",
device_name="Expired Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
session.add(expired_session)
await session.commit()
# Cleanup for user
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
session, user_id=str(async_test_user.id)
)
assert count == 1
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db):
"""Test cleanup with invalid user UUID."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError, match="Invalid user ID format"):
await session_crud.cleanup_expired_for_user(
session, user_id="not-a-valid-uuid"
)
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_keeps_active(
self, async_test_db, async_test_user
):
"""Test that cleanup for user keeps active sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired but active session
async with AsyncTestingSessionLocal() as session:
active_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="active_user_expired",
device_name="Active Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True, # Active
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
session.add(active_session)
await session.commit()
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
session, user_id=str(async_test_user.id)
)
assert count == 0 # Should not delete active sessions
class TestGetUserSessionsWithUser:
"""Tests for get_user_sessions with eager loading."""
@pytest.mark.asyncio
async def test_get_user_sessions_with_user_relationship(
self, async_test_db, async_test_user
):
"""Test getting sessions with user relationship loaded."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti="with_user",
device_name="Test Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
# Get with user relationship
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session, user_id=str(async_test_user.id), with_user=True
)
assert len(results) >= 1

View File

@@ -0,0 +1,387 @@
# tests/crud/test_session_db_failures.py
"""
Comprehensive tests for session CRUD database failure scenarios.
"""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.exc import OperationalError
from app.core.repository_exceptions import IntegrityConstraintError
from app.repositories.session import session_repo as session_crud
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate
class TestSessionCRUDGetByJtiFailures:
"""Test get_by_jti exception handling."""
@pytest.mark.asyncio
async def test_get_by_jti_database_error(self, async_test_db):
"""Test get_by_jti handles database errors."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("DB connection lost", {}, Exception())
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_by_jti(session, jti="test_jti")
class TestSessionCRUDGetActiveByJtiFailures:
"""Test get_active_by_jti exception handling."""
@pytest.mark.asyncio
async def test_get_active_by_jti_database_error(self, async_test_db):
"""Test get_active_by_jti handles database errors."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Query timeout", {}, Exception())
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_active_by_jti(session, jti="test_jti")
class TestSessionCRUDGetUserSessionsFailures:
"""Test get_user_sessions exception handling."""
@pytest.mark.asyncio
async def test_get_user_sessions_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_sessions handles database errors."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Database error", {}, Exception())
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_sessions(
session, user_id=str(async_test_user.id)
)
class TestSessionCRUDCreateSessionFailures:
"""Test create_session exception handling."""
@pytest.mark.asyncio
async def test_create_session_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles commit failures with rollback."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Commit failed", {}, Exception())
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
with pytest.raises(IntegrityConstraintError, match="Failed to create session"):
await session_crud.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_create_session_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles unexpected errors."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Unexpected error")
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
with pytest.raises(IntegrityConstraintError, match="Failed to create session"):
await session_crud.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once()
class TestSessionCRUDDeactivateFailures:
"""Test deactivate exception handling."""
@pytest.mark.asyncio
async def test_deactivate_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate handles commit failures."""
_test_engine, SessionLocal = async_test_db
# Create a session first
async with SessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
await session.refresh(user_session)
session_id = user_session.id
# Test deactivate failure
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Deactivate failed", {}, Exception())
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate(
session, session_id=str(session_id)
)
mock_rollback.assert_called_once()
class TestSessionCRUDDeactivateAllFailures:
"""Test deactivate_all_user_sessions exception handling."""
@pytest.mark.asyncio
async def test_deactivate_all_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate_all handles commit failures."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Bulk deactivate failed", {}, Exception())
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
class TestSessionCRUDUpdateLastUsedFailures:
"""Test update_last_used exception handling."""
@pytest.mark.asyncio
async def test_update_last_used_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_last_used handles commit failures."""
_test_engine, SessionLocal = async_test_db
# Create a session
async with SessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
await session.refresh(user_session)
# Test update failure
async with SessionLocal() as session:
from sqlalchemy import select
from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one()
async def mock_commit():
raise OperationalError("Update failed", {}, Exception())
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_last_used(session, session=sess)
mock_rollback.assert_called_once()
class TestSessionCRUDUpdateRefreshTokenFailures:
"""Test update_refresh_token exception handling."""
@pytest.mark.asyncio
async def test_update_refresh_token_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_refresh_token handles commit failures."""
_test_engine, SessionLocal = async_test_db
# Create a session
async with SessionLocal() as session:
user_session = UserSession(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
await session.refresh(user_session)
# Test update failure
async with SessionLocal() as session:
from sqlalchemy import select
from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one()
async def mock_commit():
raise OperationalError("Token update failed", {}, Exception())
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_refresh_token(
session,
session=sess,
new_jti=str(uuid4()),
new_expires_at=datetime.now(UTC) + timedelta(days=14),
)
mock_rollback.assert_called_once()
class TestSessionCRUDCleanupExpiredFailures:
"""Test cleanup_expired exception handling."""
@pytest.mark.asyncio
async def test_cleanup_expired_commit_failure_triggers_rollback(
self, async_test_db
):
"""Test cleanup_expired handles commit failures."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Cleanup failed", {}, Exception())
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired(session, keep_days=30)
mock_rollback.assert_called_once()
class TestSessionCRUDCleanupExpiredForUserFailures:
"""Test cleanup_expired_for_user exception handling."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test cleanup_expired_for_user handles commit failures."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("User cleanup failed", {}, Exception())
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired_for_user(
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
class TestSessionCRUDGetUserSessionCountFailures:
"""Test get_user_session_count exception handling."""
@pytest.mark.asyncio
async def test_get_user_session_count_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_session_count handles database errors."""
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Count query failed", {}, Exception())
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_session_count(
session, user_id=str(async_test_user.id)
)

View File

@@ -0,0 +1,665 @@
# tests/crud/test_user_async.py
"""
Comprehensive tests for async user CRUD operations.
"""
import pytest
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
from app.repositories.user import user_repo as user_crud
from app.schemas.users import UserCreate, UserUpdate
class TestGetByEmail:
"""Tests for get_by_email method."""
@pytest.mark.asyncio
async def test_get_by_email_success(self, async_test_db, async_test_user):
"""Test getting user by email."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email=async_test_user.email)
assert result is not None
assert result.email == async_test_user.email
assert result.id == async_test_user.id
@pytest.mark.asyncio
async def test_get_by_email_not_found(self, async_test_db):
"""Test getting non-existent email returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(
session, email="nonexistent@example.com"
)
assert result is None
class TestCreate:
"""Tests for create method."""
@pytest.mark.asyncio
async def test_create_user_success(self, async_test_db):
"""Test successfully creating a user_crud."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="newuser@example.com",
password="SecurePass123!",
first_name="New",
last_name="User",
phone_number="+1234567890",
)
result = await user_crud.create(session, obj_in=user_data)
assert result.email == "newuser@example.com"
assert result.first_name == "New"
assert result.last_name == "User"
assert result.phone_number == "+1234567890"
assert result.is_active is True
assert result.is_superuser is False
assert result.password_hash is not None
assert result.password_hash != "SecurePass123!" # Password should be hashed
@pytest.mark.asyncio
async def test_create_superuser_success(self, async_test_db):
"""Test creating a superuser."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="superuser@example.com",
password="SuperPass123!",
first_name="Super",
last_name="User",
is_superuser=True,
)
result = await user_crud.create(session, obj_in=user_data)
assert result.is_superuser is True
assert result.email == "superuser@example.com"
@pytest.mark.asyncio
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
"""Test creating user with duplicate email raises ValueError."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email=async_test_user.email, # Duplicate email
password="AnotherPass123!",
first_name="Duplicate",
last_name="User",
)
with pytest.raises(DuplicateEntryError) as exc_info:
await user_crud.create(session, obj_in=user_data)
assert "already exists" in str(exc_info.value).lower()
class TestUpdate:
"""Tests for update method."""
@pytest.mark.asyncio
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
"""Test updating basic user fields."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get fresh copy of user
user = await user_crud.get(session, id=str(async_test_user.id))
update_data = UserUpdate(
first_name="Updated", last_name="Name", phone_number="+9876543210"
)
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
assert result.first_name == "Updated"
assert result.last_name == "Name"
assert result.phone_number == "+9876543210"
@pytest.mark.asyncio
async def test_update_user_password(self, async_test_db):
"""Test updating user password."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a fresh user for this test
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="passwordtest@example.com",
password="OldPassword123!",
first_name="Pass",
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
old_password_hash = user.password_hash
# Update the password
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(user_id))
update_data = UserUpdate(password="NewDifferentPassword123!")
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
await session.refresh(result)
assert result.password_hash != old_password_hash
assert result.password_hash is not None
assert (
"NewDifferentPassword123!" not in result.password_hash
) # Should be hashed
@pytest.mark.asyncio
async def test_update_user_with_dict(self, async_test_db, async_test_user):
"""Test updating user with dictionary."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
update_dict = {"first_name": "DictUpdate"}
result = await user_crud.update(session, db_obj=user, obj_in=update_dict)
assert result.first_name == "DictUpdate"
class TestGetMultiWithTotal:
"""Tests for get_multi_with_total method."""
@pytest.mark.asyncio
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
"""Test basic pagination."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session, skip=0, limit=10
)
assert total >= 1
assert len(users) >= 1
assert any(u.id == async_test_user.id for u in users)
@pytest.mark.asyncio
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
"""Test sorting in ascending order."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
for i in range(3):
user_data = UserCreate(
email=f"sort{i}@example.com",
password="SecurePass123!",
first_name=f"User{i}",
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="asc"
)
# Check if sorted (at least the test users)
test_users = [u for u in users if u.email.startswith("sort")]
if len(test_users) > 1:
assert test_users[0].email < test_users[1].email
@pytest.mark.asyncio
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
"""Test sorting in descending order."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
for i in range(3):
user_data = UserCreate(
email=f"desc{i}@example.com",
password="SecurePass123!",
first_name=f"User{i}",
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="desc"
)
# Check if sorted descending (at least the test users)
test_users = [u for u in users if u.email.startswith("desc")]
if len(test_users) > 1:
assert test_users[0].email > test_users[1].email
@pytest.mark.asyncio
async def test_get_multi_with_total_filtering(self, async_test_db):
"""Test filtering by field."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create active and inactive users
async with AsyncTestingSessionLocal() as session:
active_user = UserCreate(
email="active@example.com",
password="SecurePass123!",
first_name="Active",
last_name="User",
)
await user_crud.create(session, obj_in=active_user)
inactive_user = UserCreate(
email="inactive@example.com",
password="SecurePass123!",
first_name="Inactive",
last_name="User",
)
created_inactive = await user_crud.create(session, obj_in=inactive_user)
# Deactivate the user
await user_crud.update(
session, db_obj=created_inactive, obj_in={"is_active": False}
)
async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=100, filters={"is_active": True}
)
# All returned users should be active
assert all(u.is_active for u in users)
@pytest.mark.asyncio
async def test_get_multi_with_total_search(self, async_test_db):
"""Test search functionality."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user with unique name
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="searchable@example.com",
password="SecurePass123!",
first_name="Searchable",
last_name="UserName",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session, skip=0, limit=100, search="Searchable"
)
assert total >= 1
assert any(u.first_name == "Searchable" for u in users)
@pytest.mark.asyncio
async def test_get_multi_with_total_pagination(self, async_test_db):
"""Test pagination with skip and limit."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
for i in range(5):
user_data = UserCreate(
email=f"page{i}@example.com",
password="SecurePass123!",
first_name=f"Page{i}",
last_name="User",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
# Get first page
users_page1, total = await user_crud.get_multi_with_total(
session, skip=0, limit=2
)
# Get second page
users_page2, total2 = await user_crud.get_multi_with_total(
session, skip=2, limit=2
)
# Total should be same
assert total == total2
# Different users on different pages
assert users_page1[0].id != users_page2[0].id
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
"""Test validation fails for negative skip."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
assert "skip must be non-negative" in str(exc_info.value)
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
"""Test validation fails for negative limit."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
assert "limit must be non-negative" in str(exc_info.value)
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
"""Test validation fails for limit > 1000."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
assert "Maximum limit is 1000" in str(exc_info.value)
class TestBulkUpdateStatus:
"""Tests for bulk_update_status method."""
@pytest.mark.asyncio
async def test_bulk_update_status_success(self, async_test_db):
"""Test bulk updating user status."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
async with AsyncTestingSessionLocal() as session:
for i in range(3):
user_data = UserCreate(
email=f"bulk{i}@example.com",
password="SecurePass123!",
first_name=f"Bulk{i}",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
# Bulk deactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session, user_ids=user_ids, is_active=False
)
assert count == 3
# Verify all are inactive
async with AsyncTestingSessionLocal() as session:
for user_id in user_ids:
user = await user_crud.get(session, id=str(user_id))
assert user.is_active is False
@pytest.mark.asyncio
async def test_bulk_update_status_empty_list(self, async_test_db):
"""Test bulk update with empty list returns 0."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session, user_ids=[], is_active=False
)
assert count == 0
@pytest.mark.asyncio
async def test_bulk_update_status_reactivate(self, async_test_db):
"""Test bulk reactivating users."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create inactive user
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="reactivate@example.com",
password="SecurePass123!",
first_name="Reactivate",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
# Deactivate
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
user_id = user.id
# Reactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session, user_ids=[user_id], is_active=True
)
assert count == 1
# Verify active
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(user_id))
assert user.is_active is True
class TestBulkSoftDelete:
"""Tests for bulk_soft_delete method."""
@pytest.mark.asyncio
async def test_bulk_soft_delete_success(self, async_test_db):
"""Test bulk soft deleting users."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
async with AsyncTestingSessionLocal() as session:
for i in range(3):
user_data = UserCreate(
email=f"delete{i}@example.com",
password="SecurePass123!",
first_name=f"Delete{i}",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
# Bulk delete
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
assert count == 3
# Verify all are soft deleted
async with AsyncTestingSessionLocal() as session:
for user_id in user_ids:
user = await user_crud.get(session, id=str(user_id))
assert user.deleted_at is not None
assert user.is_active is False
@pytest.mark.asyncio
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
"""Test bulk soft delete with excluded user_crud."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
async with AsyncTestingSessionLocal() as session:
for i in range(3):
user_data = UserCreate(
email=f"exclude{i}@example.com",
password="SecurePass123!",
first_name=f"Exclude{i}",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
# Bulk delete, excluding first user
exclude_id = user_ids[0]
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session, user_ids=user_ids, exclude_user_id=exclude_id
)
assert count == 2 # Only 2 deleted
# Verify excluded user is NOT deleted
async with AsyncTestingSessionLocal() as session:
excluded_user = await user_crud.get(session, id=str(exclude_id))
assert excluded_user.deleted_at is None
@pytest.mark.asyncio
async def test_bulk_soft_delete_empty_list(self, async_test_db):
"""Test bulk delete with empty list returns 0."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=[])
assert count == 0
@pytest.mark.asyncio
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
"""Test bulk delete where all users are excluded."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="onlyuser@example.com",
password="SecurePass123!",
first_name="Only",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
# Try to delete but exclude
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session, user_ids=[user_id], exclude_user_id=user_id
)
assert count == 0
@pytest.mark.asyncio
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
"""Test bulk delete doesn't re-delete already deleted users."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create and delete user
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="predeleted@example.com",
password="SecurePass123!",
first_name="PreDeleted",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
# First deletion
await user_crud.bulk_soft_delete(session, user_ids=[user_id])
# Try to delete again
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
assert count == 0 # Already deleted
class TestUtilityMethods:
"""Tests for utility methods."""
@pytest.mark.asyncio
async def test_is_active_true(self, async_test_db, async_test_user):
"""Test is_active returns True for active user_crud."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
assert user_crud.is_active(user) is True
@pytest.mark.asyncio
async def test_is_active_false(self, async_test_db):
"""Test is_active returns False for inactive user_crud."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="inactive2@example.com",
password="SecurePass123!",
first_name="Inactive",
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
assert user_crud.is_active(user) is False
@pytest.mark.asyncio
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
"""Test is_superuser returns True for superuser."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_superuser.id))
assert user_crud.is_superuser(user) is True
@pytest.mark.asyncio
async def test_is_superuser_false(self, async_test_db, async_test_user):
"""Test is_superuser returns False for regular user_crud."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
assert user_crud.is_superuser(user) is False
class TestUserExceptionHandlers:
"""
Test exception handlers in user CRUD methods.
Covers lines: 30-32, 205-208, 257-260
"""
@pytest.mark.asyncio
async def test_get_by_email_database_error(self, async_test_db):
"""Test get_by_email handles database errors (covers lines 30-32)."""
from unittest.mock import patch
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with patch.object(
session, "execute", side_effect=Exception("Database query failed")
):
with pytest.raises(Exception, match="Database query failed"):
await user_crud.get_by_email(session, email="test@example.com")
@pytest.mark.asyncio
async def test_bulk_update_status_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_update_status handles database errors (covers lines 205-208)."""
from unittest.mock import AsyncMock, patch
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock execute to fail
with patch.object(
session, "execute", side_effect=Exception("Bulk update failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk update failed"):
await user_crud.bulk_update_status(
session, user_ids=[async_test_user.id], is_active=False
)
@pytest.mark.asyncio
async def test_bulk_soft_delete_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_soft_delete handles database errors (covers lines 257-260)."""
from unittest.mock import AsyncMock, patch
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock execute to fail
with patch.object(
session, "execute", side_effect=Exception("Bulk delete failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk delete failed"):
await user_crud.bulk_soft_delete(
session, user_ids=[async_test_user.id]
)