refactor(tests): replace crud references with repo across repository test files

- Updated import statements and test logic to align with `repositories` naming changes.
- Adjusted documentation and test names for consistency with the updated naming convention.
- Improved test descriptions to reflect the repository-based structure.
This commit is contained in:
2026-03-01 19:22:16 +01:00
parent 07309013d7
commit a3f78dc801
38 changed files with 409 additions and 409 deletions

View File

@@ -37,7 +37,7 @@ Default superuser (change in production):
│ ├── app/ │ ├── app/
│ │ ├── api/ # API routes (auth, users, organizations, admin) │ │ ├── api/ # API routes (auth, users, organizations, admin)
│ │ ├── core/ # Core functionality (auth, config, database) │ │ ├── core/ # Core functionality (auth, config, database)
│ │ ├── crud/ # Database CRUD operations │ │ ├── repositories/ # Repository pattern (database operations)
│ │ ├── models/ # SQLAlchemy ORM models │ │ ├── models/ # SQLAlchemy ORM models
│ │ ├── schemas/ # Pydantic request/response schemas │ │ ├── schemas/ # Pydantic request/response schemas
│ │ ├── services/ # Business logic layer │ │ ├── services/ # Business logic layer
@@ -113,7 +113,7 @@ OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in
### Database Pattern ### Database Pattern
- **Async SQLAlchemy 2.0** with PostgreSQL - **Async SQLAlchemy 2.0** with PostgreSQL
- **Connection pooling**: 20 base connections, 50 max overflow - **Connection pooling**: 20 base connections, 50 max overflow
- **CRUD base class**: `crud/base.py` with common operations - **Repository base class**: `repositories/base.py` with common operations
- **Migrations**: Alembic with helper script `migrate.py` - **Migrations**: Alembic with helper script `migrate.py`
- `python migrate.py auto "message"` - Generate and apply - `python migrate.py auto "message"` - Generate and apply
- `python migrate.py list` - View history - `python migrate.py list` - View history
@@ -222,7 +222,7 @@ NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
### Adding a New API Endpoint ### Adding a New API Endpoint
1. **Define schema** in `backend/app/schemas/` 1. **Define schema** in `backend/app/schemas/`
2. **Create CRUD operations** in `backend/app/crud/` 2. **Create repository** in `backend/app/repositories/`
3. **Implement route** in `backend/app/api/routes/` 3. **Implement route** in `backend/app/api/routes/`
4. **Register router** in `backend/app/api/main.py` 4. **Register router** in `backend/app/api/main.py`
5. **Write tests** in `backend/tests/api/` 5. **Write tests** in `backend/tests/api/`
@@ -289,7 +289,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
- Authentication system (JWT with refresh tokens, OAuth/social login) - Authentication system (JWT with refresh tokens, OAuth/social login)
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server - **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
- Session management (device tracking, revocation) - Session management (device tracking, revocation)
- User management (CRUD, password change) - User management (full lifecycle, password change)
- Organization system (multi-tenant with RBAC) - Organization system (multi-tenant with RBAC)
- Admin panel (user/org management, bulk operations) - Admin panel (user/org management, bulk operations)
- **Internationalization (i18n)** with English and Italian - **Internationalization (i18n)** with English and Italian

View File

@@ -148,7 +148,7 @@ async def mock_commit():
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await crud_method(session, obj_in=data) await repo_method(session, obj_in=data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
``` ```
@@ -171,7 +171,7 @@ with patch.object(session, 'commit', side_effect=mock_commit):
### Common Workflows Guidance ### Common Workflows Guidance
**When Adding a New Feature:** **When Adding a New Feature:**
1. Start with backend schema and CRUD 1. Start with backend schema and repository
2. Implement API route with proper authorization 2. Implement API route with proper authorization
3. Write backend tests (aim for >90% coverage) 3. Write backend tests (aim for >90% coverage)
4. Generate frontend API client: `bun run generate:api` 4. Generate frontend API client: `bun run generate:api`
@@ -224,7 +224,7 @@ with patch.object(session, 'commit', side_effect=mock_commit):
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill. No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
**Potential skill ideas for this project:** **Potential skill ideas for this project:**
- API endpoint generator workflow (schema → CRUD → route → tests → frontend client) - API endpoint generator workflow (schema → repository → route → tests → frontend client)
- Component generator with design system compliance - Component generator with design system compliance
- Database migration troubleshooting helper - Database migration troubleshooting helper
- Test coverage analyzer and improvement suggester - Test coverage analyzer and improvement suggester

View File

@@ -204,7 +204,7 @@ export function UserProfile({ userId }: UserProfileProps) {
### Key Patterns ### Key Patterns
- **Backend**: Use CRUD pattern, keep routes thin, business logic in services - **Backend**: Use repository pattern, keep routes thin, business logic in services
- **Frontend**: Use React Query for server state, Zustand for client state - **Frontend**: Use React Query for server state, Zustand for client state
- **Both**: Handle errors gracefully, log appropriately, write tests - **Both**: Handle errors gracefully, log appropriately, write tests

View File

@@ -58,7 +58,7 @@ Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-p
- User can belong to multiple organizations - User can belong to multiple organizations
### 🛠️ **Admin Panel** ### 🛠️ **Admin Panel**
- Complete user management (CRUD, activate/deactivate, bulk operations) - Complete user management (full lifecycle, activate/deactivate, bulk operations)
- Organization management (create, edit, delete, member management) - Organization management (create, edit, delete, member management)
- Session monitoring across all users - Session monitoring across all users
- Real-time statistics dashboard - Real-time statistics dashboard
@@ -322,7 +322,7 @@ Visit http://localhost:3000 to see your app!
│ ├── app/ │ ├── app/
│ │ ├── api/ # API routes and dependencies │ │ ├── api/ # API routes and dependencies
│ │ ├── core/ # Core functionality (auth, config, database) │ │ ├── core/ # Core functionality (auth, config, database)
│ │ ├── crud/ # Database operations │ │ ├── repositories/ # Repository pattern (database operations)
│ │ ├── models/ # SQLAlchemy models │ │ ├── models/ # SQLAlchemy models
│ │ ├── schemas/ # Pydantic schemas │ │ ├── schemas/ # Pydantic schemas
│ │ ├── services/ # Business logic │ │ ├── services/ # Business logic
@@ -377,7 +377,7 @@ open htmlcov/index.html
``` ```
**Test types:** **Test types:**
- **Unit tests**: CRUD operations, utilities, business logic - **Unit tests**: Repository operations, utilities, business logic
- **Integration tests**: API endpoints with database - **Integration tests**: API endpoints with database
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation - **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
- **Error handling tests**: Database failures, validation errors - **Error handling tests**: Database failures, validation errors
@@ -542,7 +542,7 @@ docker-compose down
### ✅ Completed ### ✅ Completed
- [x] Authentication system (JWT, refresh tokens, session management, OAuth) - [x] Authentication system (JWT, refresh tokens, session management, OAuth)
- [x] User management (CRUD, profile, password change) - [x] User management (full lifecycle, profile, password change)
- [x] Organization system with RBAC (Owner, Admin, Member) - [x] Organization system with RBAC (Owner, Admin, Member)
- [x] Admin panel (users, organizations, sessions, statistics) - [x] Admin panel (users, organizations, sessions, statistics)
- [x] **Internationalization (i18n)** with next-intl (English + Italian) - [x] **Internationalization (i18n)** with next-intl (English + Italian)

View File

@@ -11,7 +11,7 @@ omit =
app/utils/auth_test_utils.py app/utils/auth_test_utils.py
# Async implementations not yet in use # Async implementations not yet in use
app/crud/base_async.py app/repositories/base_async.py
app/core/database_async.py app/core/database_async.py
# CLI scripts - run manually, not tested # CLI scripts - run manually, not tested
@@ -23,7 +23,7 @@ omit =
app/api/routes/__init__.py app/api/routes/__init__.py
app/api/dependencies/__init__.py app/api/dependencies/__init__.py
app/core/__init__.py app/core/__init__.py
app/crud/__init__.py app/repositories/__init__.py
app/models/__init__.py app/models/__init__.py
app/schemas/__init__.py app/schemas/__init__.py
app/services/__init__.py app/services/__init__.py

View File

@@ -264,7 +264,7 @@ app/
│ ├── database.py # Database engine setup │ ├── database.py # Database engine setup
│ ├── auth.py # JWT token handling │ ├── auth.py # JWT token handling
│ └── exceptions.py # Custom exceptions │ └── exceptions.py # Custom exceptions
├── crud/ # Database operations ├── repositories/ # Repository pattern (database operations)
├── models/ # SQLAlchemy ORM models ├── models/ # SQLAlchemy ORM models
├── schemas/ # Pydantic request/response schemas ├── schemas/ # Pydantic request/response schemas
├── services/ # Business logic layer ├── services/ # Business logic layer
@@ -462,7 +462,7 @@ See [docs/FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) for step-by-step guide.
Quick overview: Quick overview:
1. Create Pydantic schemas in `app/schemas/` 1. Create Pydantic schemas in `app/schemas/`
2. Create CRUD operations in `app/crud/` 2. Create repository in `app/repositories/`
3. Create route in `app/api/routes/` 3. Create route in `app/api/routes/`
4. Register router in `app/api/main.py` 4. Register router in `app/api/main.py`
5. Write tests in `tests/api/` 5. Write tests in `tests/api/`

View File

@@ -1,5 +1,5 @@
""" """
User management endpoints for CRUD operations. User management endpoints for database operations.
""" """
import logging import logging

View File

@@ -128,8 +128,8 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
Usage: Usage:
async with async_transaction_scope() as db: async with async_transaction_scope() as db:
user = await user_crud.create(db, obj_in=user_create) user = await user_repo.create(db, obj_in=user_create)
profile = await profile_crud.create(db, obj_in=profile_create) profile = await profile_repo.create(db, obj_in=profile_create)
# Both operations committed together # Both operations committed together
""" """
async with SessionLocal() as session: async with SessionLocal() as session:

View File

@@ -19,7 +19,7 @@ from app.core.database import SessionLocal, engine
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import UserOrganization from app.models.user_organization import UserOrganization
from app.repositories.user import user_repo as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -51,7 +51,7 @@ async def init_db() -> User | None:
async with SessionLocal() as session: async with SessionLocal() as session:
try: try:
# Check if superuser already exists # Check if superuser already exists
existing_user = await user_crud.get_by_email(session, email=superuser_email) existing_user = await user_repo.get_by_email(session, email=superuser_email)
if existing_user: if existing_user:
logger.info("Superuser already exists: %s", existing_user.email) logger.info("Superuser already exists: %s", existing_user.email)
@@ -66,7 +66,7 @@ async def init_db() -> User | None:
is_superuser=True, is_superuser=True,
) )
user = await user_crud.create(session, obj_in=user_in) user = await user_repo.create(session, obj_in=user_in)
await session.commit() await session.commit()
await session.refresh(user) await session.refresh(user)
@@ -136,7 +136,7 @@ async def load_demo_data(session):
# Create Users # Create Users
for user_data in data.get("users", []): for user_data in data.get("users", []):
existing_user = await user_crud.get_by_email( existing_user = await user_repo.get_by_email(
session, email=user_data["email"] session, email=user_data["email"]
) )
if not existing_user: if not existing_user:
@@ -149,7 +149,7 @@ async def load_demo_data(session):
is_superuser=user_data["is_superuser"], is_superuser=user_data["is_superuser"],
is_active=user_data.get("is_active", True), is_active=user_data.get("is_active", True),
) )
user = await user_crud.create(session, obj_in=user_in) user = await user_repo.create(session, obj_in=user_in)
# Randomize created_at for demo data (last 30 days) # Randomize created_at for demo data (last 30 days)
# This makes the charts look more realistic # This makes the charts look more realistic

View File

@@ -1,6 +1,6 @@
# app/repositories/base.py # app/repositories/base.py
""" """
Base repository class for async CRUD operations using SQLAlchemy 2.0 async patterns. Base repository class for async database operations using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models. Provides reusable create, read, update, and delete operations for all models.
""" """

View File

@@ -1,5 +1,5 @@
# app/repositories/oauth_account.py # app/repositories/oauth_account.py
"""Repository for OAuthAccount model async CRUD operations.""" """Repository for OAuthAccount model async database operations."""
import logging import logging
from datetime import datetime from datetime import datetime

View File

@@ -1,5 +1,5 @@
# app/repositories/oauth_client.py # app/repositories/oauth_client.py
"""Repository for OAuthClient model async CRUD operations.""" """Repository for OAuthClient model async database operations."""
import logging import logging
import secrets import secrets

View File

@@ -1,5 +1,5 @@
# app/repositories/oauth_state.py # app/repositories/oauth_state.py
"""Repository for OAuthState model async CRUD operations.""" """Repository for OAuthState model async database operations."""
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime

View File

@@ -1,5 +1,5 @@
# app/repositories/organization.py # app/repositories/organization.py
"""Repository for Organization model async CRUD operations using SQLAlchemy 2.0 patterns.""" """Repository for Organization model async database operations using SQLAlchemy 2.0 patterns."""
import logging import logging
from typing import Any from typing import Any

View File

@@ -1,5 +1,5 @@
# app/repositories/session.py # app/repositories/session.py
"""Repository for UserSession model async CRUD operations using SQLAlchemy 2.0 patterns.""" """Repository for UserSession model async database operations using SQLAlchemy 2.0 patterns."""
import logging import logging
import uuid import uuid

View File

@@ -1,5 +1,5 @@
# app/repositories/user.py # app/repositories/user.py
"""Repository for User model async CRUD operations using SQLAlchemy 2.0 patterns.""" """Repository for User model async database operations using SQLAlchemy 2.0 patterns."""
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime

View File

@@ -8,7 +8,7 @@ import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from app.core.database import SessionLocal from app.core.database import SessionLocal
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -32,8 +32,8 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
async with SessionLocal() as db: async with SessionLocal() as db:
try: try:
# Use CRUD method to cleanup # Use repository method to cleanup
count = await session_crud.cleanup_expired(db, keep_days=keep_days) count = await session_repo.cleanup_expired(db, keep_days=keep_days)
logger.info("Session cleanup complete: %s sessions deleted", count) logger.info("Session cleanup complete: %s sessions deleted", count)

View File

@@ -214,7 +214,7 @@ if not user:
### Error Handling Pattern ### Error Handling Pattern
Always follow this pattern in CRUD operations (Async version): Always follow this pattern in repository operations (Async version):
```python ```python
from sqlalchemy.exc import IntegrityError, OperationalError, DataError from sqlalchemy.exc import IntegrityError, OperationalError, DataError
@@ -427,7 +427,7 @@ backend/app/alembic/versions/
## Database Operations ## Database Operations
### Async CRUD Pattern ### Async Repository Pattern
**IMPORTANT**: This application uses **async SQLAlchemy** with modern patterns for better performance and testability. **IMPORTANT**: This application uses **async SQLAlchemy** with modern patterns for better performance and testability.
@@ -567,7 +567,7 @@ async def create_user(
**Key Points:** **Key Points:**
- Route functions must be `async def` - Route functions must be `async def`
- Database parameter is `AsyncSession` - Database parameter is `AsyncSession`
- Always `await` CRUD operations - Always `await` repository operations
#### In Services (Multiple Operations) #### In Services (Multiple Operations)

View File

@@ -334,14 +334,14 @@ def login(request: Request, credentials: OAuth2PasswordRequestForm):
# ❌ WRONG - Returns password hash! # ❌ WRONG - Returns password hash!
@router.get("/users/{user_id}") @router.get("/users/{user_id}")
def get_user(user_id: UUID, db: Session = Depends(get_db)) -> User: def get_user(user_id: UUID, db: Session = Depends(get_db)) -> User:
return user_crud.get(db, id=user_id) # Returns ORM model with ALL fields! return user_repo.get(db, id=user_id) # Returns ORM model with ALL fields!
``` ```
```python ```python
# ✅ CORRECT - Use response schema # ✅ CORRECT - Use response schema
@router.get("/users/{user_id}", response_model=UserResponse) @router.get("/users/{user_id}", response_model=UserResponse)
def get_user(user_id: UUID, db: Session = Depends(get_db)): def get_user(user_id: UUID, db: Session = Depends(get_db)):
user = user_crud.get(db, id=user_id) user = user_repo.get(db, id=user_id)
if not user: if not user:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
return user # Pydantic filters to only UserResponse fields return user # Pydantic filters to only UserResponse fields
@@ -506,8 +506,8 @@ def revoke_session(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
session = session_crud.get(db, id=session_id) session = session_repo.get(db, id=session_id)
session_crud.deactivate(db, session_id=session_id) session_repo.deactivate(db, session_id=session_id)
# BUG: User can revoke ANYONE'S session! # BUG: User can revoke ANYONE'S session!
return {"message": "Session revoked"} return {"message": "Session revoked"}
``` ```
@@ -520,7 +520,7 @@ def revoke_session(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
session = session_crud.get(db, id=session_id) session = session_repo.get(db, id=session_id)
if not session: if not session:
raise NotFoundError("Session not found") raise NotFoundError("Session not found")
@@ -529,7 +529,7 @@ def revoke_session(
if session.user_id != current_user.id: if session.user_id != current_user.id:
raise AuthorizationError("You can only revoke your own sessions") raise AuthorizationError("You can only revoke your own sessions")
session_crud.deactivate(db, session_id=session_id) session_repo.deactivate(db, session_id=session_id)
return {"message": "Session revoked"} return {"message": "Session revoked"}
``` ```

View File

@@ -99,7 +99,7 @@ backend/tests/
│ └── test_database_workflows.py # PostgreSQL workflow tests │ └── test_database_workflows.py # PostgreSQL workflow tests
├── api/ # Integration tests (SQLite, fast) ├── api/ # Integration tests (SQLite, fast)
├── crud/ # Unit tests ├── repositories/ # Repository unit tests
└── conftest.py # Standard fixtures └── conftest.py # Standard fixtures
``` ```

View File

@@ -13,7 +13,7 @@ import pytest
from httpx import AsyncClient from httpx import AsyncClient
from app.models.user import User from app.models.user import User
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
class TestRevokedSessionSecurity: class TestRevokedSessionSecurity:
@@ -117,7 +117,7 @@ class TestRevokedSessionSecurity:
async with SessionLocal() as session: async with SessionLocal() as session:
# Find and delete the session # Find and delete the session
db_session = await session_crud.get_by_jti(session, jti=jti) db_session = await session_repo.get_by_jti(session, jti=jti)
if db_session: if db_session:
await session.delete(db_session) await session.delete(db_session)
await session.commit() await session.commit()

View File

@@ -13,7 +13,7 @@ from httpx import AsyncClient
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.repositories.user import user_repo as user_crud from app.repositories.user import user_repo as user_repo
class TestInactiveUserBlocking: class TestInactiveUserBlocking:
@@ -50,7 +50,7 @@ class TestInactiveUserBlocking:
# Step 2: Admin deactivates the user # Step 2: Admin deactivates the user
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=async_test_user.id) user = await user_repo.get(session, id=async_test_user.id)
user.is_active = False user.is_active = False
await session.commit() await session.commit()
@@ -80,7 +80,7 @@ class TestInactiveUserBlocking:
# Deactivate user # Deactivate user
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=async_test_user.id) user = await user_repo.get(session, id=async_test_user.id)
user.is_active = False user.is_active = False
await session.commit() await session.commit()

View File

@@ -39,7 +39,7 @@ async def async_test_user2(async_test_db):
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
from app.repositories.user import user_repo as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
user_data = UserCreate( user_data = UserCreate(
@@ -48,7 +48,7 @@ async def async_test_user2(async_test_db):
first_name="Test", first_name="Test",
last_name="User2", last_name="User2",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
await session.commit() await session.commit()
await session.refresh(user) await session.refresh(user)
return user return user
@@ -191,9 +191,9 @@ class TestRevokeSession:
# Verify session is deactivated # Verify session is deactivated
async with SessionLocal() as session: async with SessionLocal() as session:
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
revoked_session = await session_crud.get(session, id=str(session_id)) revoked_session = await session_repo.get(session, id=str(session_id))
assert revoked_session.is_active is False assert revoked_session.is_active is False
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -267,8 +267,8 @@ class TestCleanupExpiredSessions:
"""Test successfully cleaning up expired sessions.""" """Test successfully cleaning up expired sessions."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create expired and active sessions using CRUD to avoid greenlet issues # Create expired and active sessions using repository to avoid greenlet issues
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
async with SessionLocal() as db: async with SessionLocal() as db:
@@ -282,7 +282,7 @@ class TestCleanupExpiredSessions:
expires_at=datetime.now(UTC) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2), last_used_at=datetime.now(UTC) - timedelta(days=2),
) )
e1 = await session_crud.create_session(db, obj_in=e1_data) e1 = await session_repo.create_session(db, obj_in=e1_data)
e1.is_active = False e1.is_active = False
db.add(e1) db.add(e1)
@@ -296,7 +296,7 @@ class TestCleanupExpiredSessions:
expires_at=datetime.now(UTC) - timedelta(hours=1), expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2), last_used_at=datetime.now(UTC) - timedelta(hours=2),
) )
e2 = await session_crud.create_session(db, obj_in=e2_data) e2 = await session_repo.create_session(db, obj_in=e2_data)
e2.is_active = False e2.is_active = False
db.add(e2) db.add(e2)
@@ -310,7 +310,7 @@ class TestCleanupExpiredSessions:
expires_at=datetime.now(UTC) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
) )
await session_crud.create_session(db, obj_in=a1_data) await session_repo.create_session(db, obj_in=a1_data)
await db.commit() await db.commit()
# Cleanup expired sessions # Cleanup expired sessions
@@ -333,8 +333,8 @@ class TestCleanupExpiredSessions:
"""Test cleanup when no sessions are expired.""" """Test cleanup when no sessions are expired."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create only active sessions using CRUD # Create only active sessions using repository
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
async with SessionLocal() as db: async with SessionLocal() as db:
@@ -347,7 +347,7 @@ class TestCleanupExpiredSessions:
expires_at=datetime.now(UTC) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
) )
await session_crud.create_session(db, obj_in=a1_data) await session_repo.create_session(db, obj_in=a1_data)
await db.commit() await db.commit()
response = await client.delete( response = await client.delete(
@@ -384,7 +384,7 @@ class TestSessionsAdditionalCases:
# Create multiple sessions # Create multiple sessions
async with SessionLocal() as session: async with SessionLocal() as session:
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
for i in range(5): for i in range(5):
@@ -397,7 +397,7 @@ class TestSessionsAdditionalCases:
expires_at=datetime.now(UTC) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
) )
await session_crud.create_session(session, obj_in=session_data) await session_repo.create_session(session, obj_in=session_data)
await session.commit() await session.commit()
response = await client.get( response = await client.get(
@@ -431,7 +431,7 @@ class TestSessionsAdditionalCases:
"""Test cleanup with mix of active/inactive and expired/not-expired sessions.""" """Test cleanup with mix of active/inactive and expired/not-expired sessions."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
async with SessionLocal() as db: async with SessionLocal() as db:
@@ -445,7 +445,7 @@ class TestSessionsAdditionalCases:
expires_at=datetime.now(UTC) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2), last_used_at=datetime.now(UTC) - timedelta(days=2),
) )
e1 = await session_crud.create_session(db, obj_in=e1_data) e1 = await session_repo.create_session(db, obj_in=e1_data)
e1.is_active = False e1.is_active = False
db.add(e1) db.add(e1)
@@ -459,7 +459,7 @@ class TestSessionsAdditionalCases:
expires_at=datetime.now(UTC) - timedelta(hours=1), expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2), last_used_at=datetime.now(UTC) - timedelta(hours=2),
) )
await session_crud.create_session(db, obj_in=e2_data) await session_repo.create_session(db, obj_in=e2_data)
await db.commit() await db.commit()
@@ -530,7 +530,7 @@ class TestSessionExceptionHandlers:
from app.repositories import session as session_module from app.repositories import session as session_module
# First create a session to revoke # First create a session to revoke
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
@@ -545,7 +545,7 @@ class TestSessionExceptionHandlers:
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60), expires_at=datetime.now(UTC) + timedelta(days=60),
) )
user_session = await session_crud.create_session(db, obj_in=session_in) user_session = await session_repo.create_session(db, obj_in=session_in)
session_id = user_session.id session_id = user_session.id
# Mock the deactivate method to raise an exception # Mock the deactivate method to raise an exception

View File

@@ -157,7 +157,7 @@ class TestListUsers:
response = await client.get("/api/v1/users") response = await client.get("/api/v1/users")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
# Note: Removed test_list_users_unexpected_error because mocking at CRUD level # Note: Removed test_list_users_unexpected_error because mocking at repository level
# causes the exception to be raised before FastAPI can handle it properly # causes the exception to be raised before FastAPI can handle it properly

View File

@@ -46,7 +46,7 @@ async def login_user(client, email: str, password: str = "SecurePassword123!"):
async def create_superuser(e2e_db_session, email: str, password: str): async def create_superuser(e2e_db_session, email: str, password: str):
"""Create a superuser directly in the database.""" """Create a superuser directly in the database."""
from app.repositories.user import user_repo as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
user_in = UserCreate( user_in = UserCreate(
@@ -56,7 +56,7 @@ async def create_superuser(e2e_db_session, email: str, password: str):
last_name="User", last_name="User",
is_superuser=True, is_superuser=True,
) )
user = await user_crud.create(e2e_db_session, obj_in=user_in) user = await user_repo.create(e2e_db_session, obj_in=user_in)
return user return user

View File

@@ -46,7 +46,7 @@ async def register_and_login(client, email: str, password: str = "SecurePassword
async def create_superuser_and_login(client, db_session): async def create_superuser_and_login(client, db_session):
"""Helper to create a superuser directly in DB and login.""" """Helper to create a superuser directly in DB and login."""
from app.repositories.user import user_repo as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
email = f"admin-{uuid4().hex[:8]}@example.com" email = f"admin-{uuid4().hex[:8]}@example.com"
@@ -60,7 +60,7 @@ async def create_superuser_and_login(client, db_session):
last_name="User", last_name="User",
is_superuser=True, is_superuser=True,
) )
await user_crud.create(db_session, obj_in=user_in) await user_repo.create(db_session, obj_in=user_in)
# Login # Login
login_resp = await client.post( login_resp = await client.post(

View File

@@ -1,6 +1,6 @@
# tests/crud/test_base.py # tests/repositories/test_base.py
""" """
Comprehensive tests for CRUDBase class covering all error paths and edge cases. Comprehensive tests for BaseRepository class covering all error paths and edge cases.
""" """
from datetime import UTC from datetime import UTC
@@ -16,11 +16,11 @@ from app.core.repository_exceptions import (
IntegrityConstraintError, IntegrityConstraintError,
InvalidInputError, InvalidInputError,
) )
from app.repositories.user import user_repo as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
class TestCRUDBaseGet: class TestRepositoryBaseGet:
"""Tests for get method covering UUID validation and options.""" """Tests for get method covering UUID validation and options."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -29,7 +29,7 @@ class TestCRUDBaseGet:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.get(session, id="invalid-uuid") result = await user_repo.get(session, id="invalid-uuid")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -38,7 +38,7 @@ class TestCRUDBaseGet:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.get(session, id=12345) # int instead of UUID result = await user_repo.get(session, id=12345) # int instead of UUID
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -48,7 +48,7 @@ class TestCRUDBaseGet:
async with SessionLocal() as session: async with SessionLocal() as session:
# Pass UUID object directly # Pass UUID object directly
result = await user_crud.get(session, id=async_test_user.id) result = await user_repo.get(session, id=async_test_user.id)
assert result is not None assert result is not None
assert result.id == async_test_user.id assert result.id == async_test_user.id
@@ -60,7 +60,7 @@ class TestCRUDBaseGet:
async with SessionLocal() as session: async with SessionLocal() as session:
# Test that options parameter is accepted and doesn't error # Test that options parameter is accepted and doesn't error
# We pass an empty list which still tests the code path # We pass an empty list which still tests the code path
result = await user_crud.get( result = await user_repo.get(
session, id=str(async_test_user.id), options=[] session, id=str(async_test_user.id), options=[]
) )
assert result is not None assert result is not None
@@ -74,10 +74,10 @@ class TestCRUDBaseGet:
# Mock execute to raise an exception # Mock execute to raise an exception
with patch.object(session, "execute", side_effect=Exception("DB error")): with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"): with pytest.raises(Exception, match="DB error"):
await user_crud.get(session, id=str(uuid4())) await user_repo.get(session, id=str(uuid4()))
class TestCRUDBaseGetMulti: class TestRepositoryBaseGetMulti:
"""Tests for get_multi method covering pagination validation and options.""" """Tests for get_multi method covering pagination validation and options."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -87,7 +87,7 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="skip must be non-negative"): with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi(session, skip=-1) await user_repo.get_multi(session, skip=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_negative_limit(self, async_test_db): async def test_get_multi_negative_limit(self, async_test_db):
@@ -96,7 +96,7 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="limit must be non-negative"): with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi(session, limit=-1) await user_repo.get_multi(session, limit=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_limit_too_large(self, async_test_db): async def test_get_multi_limit_too_large(self, async_test_db):
@@ -105,7 +105,7 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"): with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi(session, limit=1001) await user_repo.get_multi(session, limit=1001)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_options(self, async_test_db, async_test_user): async def test_get_multi_with_options(self, async_test_db, async_test_user):
@@ -114,7 +114,7 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session: async with SessionLocal() as session:
# Test that options parameter is accepted # Test that options parameter is accepted
results = await user_crud.get_multi(session, skip=0, limit=10, options=[]) results = await user_repo.get_multi(session, skip=0, limit=10, options=[])
assert isinstance(results, list) assert isinstance(results, list)
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -125,10 +125,10 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, "execute", side_effect=Exception("DB error")): with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"): with pytest.raises(Exception, match="DB error"):
await user_crud.get_multi(session) await user_repo.get_multi(session)
class TestCRUDBaseCreate: class TestRepositoryBaseCreate:
"""Tests for create method covering various error conditions.""" """Tests for create method covering various error conditions."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -146,7 +146,7 @@ class TestCRUDBaseCreate:
) )
with pytest.raises(DuplicateEntryError, match="already exists"): with pytest.raises(DuplicateEntryError, match="already exists"):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_integrity_error_non_duplicate(self, async_test_db): async def test_create_integrity_error_non_duplicate(self, async_test_db):
@@ -173,11 +173,11 @@ class TestCRUDBaseCreate:
with pytest.raises( with pytest.raises(
DuplicateEntryError, match="Database integrity error" DuplicateEntryError, match="Database integrity error"
): ):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_operational_error(self, async_test_db): async def test_create_operational_error(self, async_test_db):
"""Test create with OperationalError (user CRUD catches as generic Exception).""" """Test create with OperationalError (user repository catches as generic Exception)."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -195,13 +195,13 @@ class TestCRUDBaseCreate:
last_name="User", last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User repository catches this as generic Exception and re-raises
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_data_error(self, async_test_db): async def test_create_data_error(self, async_test_db):
"""Test create with DataError (user CRUD catches as generic Exception).""" """Test create with DataError (user repository catches as generic Exception)."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -217,9 +217,9 @@ class TestCRUDBaseCreate:
last_name="User", last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User repository catches this as generic Exception and re-raises
with pytest.raises(DataError): with pytest.raises(DataError):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db): async def test_create_unexpected_error(self, async_test_db):
@@ -238,10 +238,10 @@ class TestCRUDBaseCreate:
) )
with pytest.raises(RuntimeError, match="Unexpected error"): with pytest.raises(RuntimeError, match="Unexpected error"):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
class TestCRUDBaseUpdate: class TestRepositoryBaseUpdate:
"""Tests for update method covering error conditions.""" """Tests for update method covering error conditions."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -251,7 +251,7 @@ class TestCRUDBaseUpdate:
# Create another user # Create another user
async with SessionLocal() as session: async with SessionLocal() as session:
from app.repositories.user import user_repo as user_crud from app.repositories.user import user_repo as user_repo
user2_data = UserCreate( user2_data = UserCreate(
email="user2@example.com", email="user2@example.com",
@@ -259,12 +259,12 @@ class TestCRUDBaseUpdate:
first_name="User", first_name="User",
last_name="Two", last_name="Two",
) )
user2 = await user_crud.create(session, obj_in=user2_data) user2 = await user_repo.create(session, obj_in=user2_data)
await session.commit() await session.commit()
# Try to update user2 with user1's email # Try to update user2 with user1's email
async with SessionLocal() as session: async with SessionLocal() as session:
user2_obj = await user_crud.get(session, id=str(user2.id)) user2_obj = await user_repo.get(session, id=str(user2.id))
with patch.object( with patch.object(
session, session,
@@ -276,7 +276,7 @@ class TestCRUDBaseUpdate:
update_data = UserUpdate(email=async_test_user.email) update_data = UserUpdate(email=async_test_user.email)
with pytest.raises(DuplicateEntryError, match="already exists"): with pytest.raises(DuplicateEntryError, match="already exists"):
await user_crud.update( await user_repo.update(
session, db_obj=user2_obj, obj_in=update_data session, db_obj=user2_obj, obj_in=update_data
) )
@@ -286,10 +286,10 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
# Update with dict (tests lines 164-165) # Update with dict (tests lines 164-165)
updated = await user_crud.update( updated = await user_repo.update(
session, db_obj=user, obj_in={"first_name": "UpdatedName"} session, db_obj=user, obj_in={"first_name": "UpdatedName"}
) )
assert updated.first_name == "UpdatedName" assert updated.first_name == "UpdatedName"
@@ -300,7 +300,7 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
with patch.object( with patch.object(
session, session,
@@ -312,7 +312,7 @@ class TestCRUDBaseUpdate:
with pytest.raises( with pytest.raises(
IntegrityConstraintError, match="Database integrity error" IntegrityConstraintError, match="Database integrity error"
): ):
await user_crud.update( await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Test"} session, db_obj=user, obj_in={"first_name": "Test"}
) )
@@ -322,7 +322,7 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
with patch.object( with patch.object(
session, session,
@@ -334,7 +334,7 @@ class TestCRUDBaseUpdate:
with pytest.raises( with pytest.raises(
IntegrityConstraintError, match="Database operation failed" IntegrityConstraintError, match="Database operation failed"
): ):
await user_crud.update( await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Test"} session, db_obj=user, obj_in={"first_name": "Test"}
) )
@@ -344,18 +344,18 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
with patch.object( with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected") session, "commit", side_effect=RuntimeError("Unexpected")
): ):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await user_crud.update( await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Test"} session, db_obj=user, obj_in={"first_name": "Test"}
) )
class TestCRUDBaseRemove: class TestRepositoryBaseRemove:
"""Tests for remove method covering UUID validation and error conditions.""" """Tests for remove method covering UUID validation and error conditions."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -364,7 +364,7 @@ class TestCRUDBaseRemove:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.remove(session, id="invalid-uuid") result = await user_repo.remove(session, id="invalid-uuid")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -380,13 +380,13 @@ class TestCRUDBaseRemove:
first_name="To", first_name="To",
last_name="Delete", last_name="Delete",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
# Delete with UUID object # Delete with UUID object
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.remove(session, id=user_id) # UUID object result = await user_repo.remove(session, id=user_id) # UUID object
assert result is not None assert result is not None
assert result.id == user_id assert result.id == user_id
@@ -396,7 +396,7 @@ class TestCRUDBaseRemove:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.remove(session, id=str(uuid4())) result = await user_repo.remove(session, id=str(uuid4()))
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -417,7 +417,7 @@ class TestCRUDBaseRemove:
IntegrityConstraintError, IntegrityConstraintError,
match="Cannot delete.*referenced by other records", match="Cannot delete.*referenced by other records",
): ):
await user_crud.remove(session, id=str(async_test_user.id)) await user_repo.remove(session, id=str(async_test_user.id))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_unexpected_error(self, async_test_db, async_test_user): async def test_remove_unexpected_error(self, async_test_db, async_test_user):
@@ -429,10 +429,10 @@ class TestCRUDBaseRemove:
session, "commit", side_effect=RuntimeError("Unexpected") session, "commit", side_effect=RuntimeError("Unexpected")
): ):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await user_crud.remove(session, id=str(async_test_user.id)) await user_repo.remove(session, id=str(async_test_user.id))
class TestCRUDBaseGetMultiWithTotal: class TestRepositoryBaseGetMultiWithTotal:
"""Tests for get_multi_with_total method covering pagination, filtering, sorting.""" """Tests for get_multi_with_total method covering pagination, filtering, sorting."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -441,7 +441,7 @@ class TestCRUDBaseGetMultiWithTotal:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total( items, total = await user_repo.get_multi_with_total(
session, skip=0, limit=10 session, skip=0, limit=10
) )
assert isinstance(items, list) assert isinstance(items, list)
@@ -455,7 +455,7 @@ class TestCRUDBaseGetMultiWithTotal:
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="skip must be non-negative"): with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi_with_total(session, skip=-1) await user_repo.get_multi_with_total(session, skip=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db): async def test_get_multi_with_total_negative_limit(self, async_test_db):
@@ -464,7 +464,7 @@ class TestCRUDBaseGetMultiWithTotal:
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="limit must be non-negative"): with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi_with_total(session, limit=-1) await user_repo.get_multi_with_total(session, limit=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db): async def test_get_multi_with_total_limit_too_large(self, async_test_db):
@@ -473,7 +473,7 @@ class TestCRUDBaseGetMultiWithTotal:
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"): with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, limit=1001) await user_repo.get_multi_with_total(session, limit=1001)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_filters( async def test_get_multi_with_total_with_filters(
@@ -484,7 +484,7 @@ class TestCRUDBaseGetMultiWithTotal:
async with SessionLocal() as session: async with SessionLocal() as session:
filters = {"email": async_test_user.email} filters = {"email": async_test_user.email}
items, total = await user_crud.get_multi_with_total( items, total = await user_repo.get_multi_with_total(
session, filters=filters session, filters=filters
) )
assert total == 1 assert total == 1
@@ -512,12 +512,12 @@ class TestCRUDBaseGetMultiWithTotal:
first_name="ZZZ", first_name="ZZZ",
last_name="User", last_name="User",
) )
await user_crud.create(session, obj_in=user_data1) await user_repo.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2) await user_repo.create(session, obj_in=user_data2)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total( items, total = await user_repo.get_multi_with_total(
session, sort_by="email", sort_order="asc" session, sort_by="email", sort_order="asc"
) )
assert total >= 3 assert total >= 3
@@ -545,12 +545,12 @@ class TestCRUDBaseGetMultiWithTotal:
first_name="CCC", first_name="CCC",
last_name="User", last_name="User",
) )
await user_crud.create(session, obj_in=user_data1) await user_repo.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2) await user_repo.create(session, obj_in=user_data2)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
items, _total = await user_crud.get_multi_with_total( items, _total = await user_repo.get_multi_with_total(
session, sort_by="email", sort_order="desc", limit=1 session, sort_by="email", sort_order="desc", limit=1
) )
assert len(items) == 1 assert len(items) == 1
@@ -570,19 +570,19 @@ class TestCRUDBaseGetMultiWithTotal:
first_name=f"User{i}", first_name=f"User{i}",
last_name="Test", last_name="Test",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
# Get first page # Get first page
items1, total = await user_crud.get_multi_with_total( items1, total = await user_repo.get_multi_with_total(
session, skip=0, limit=2 session, skip=0, limit=2
) )
assert len(items1) == 2 assert len(items1) == 2
assert total >= 3 assert total >= 3
# Get second page # Get second page
items2, total2 = await user_crud.get_multi_with_total( items2, total2 = await user_repo.get_multi_with_total(
session, skip=2, limit=2 session, skip=2, limit=2
) )
assert len(items2) >= 1 assert len(items2) >= 1
@@ -594,7 +594,7 @@ class TestCRUDBaseGetMultiWithTotal:
assert ids1.isdisjoint(ids2) assert ids1.isdisjoint(ids2)
class TestCRUDBaseCount: class TestRepositoryBaseCount:
"""Tests for count method.""" """Tests for count method."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -603,7 +603,7 @@ class TestCRUDBaseCount:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
count = await user_crud.count(session) count = await user_repo.count(session)
assert isinstance(count, int) assert isinstance(count, int)
assert count >= 1 # At least the test user assert count >= 1 # At least the test user
@@ -614,7 +614,7 @@ class TestCRUDBaseCount:
# Create additional users # Create additional users
async with SessionLocal() as session: async with SessionLocal() as session:
initial_count = await user_crud.count(session) initial_count = await user_repo.count(session)
user_data1 = UserCreate( user_data1 = UserCreate(
email="count1@example.com", email="count1@example.com",
@@ -628,12 +628,12 @@ class TestCRUDBaseCount:
first_name="Count", first_name="Count",
last_name="Two", last_name="Two",
) )
await user_crud.create(session, obj_in=user_data1) await user_repo.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2) await user_repo.create(session, obj_in=user_data2)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
new_count = await user_crud.count(session) new_count = await user_repo.count(session)
assert new_count == initial_count + 2 assert new_count == initial_count + 2
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -644,10 +644,10 @@ class TestCRUDBaseCount:
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, "execute", side_effect=Exception("DB error")): with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"): with pytest.raises(Exception, match="DB error"):
await user_crud.count(session) await user_repo.count(session)
class TestCRUDBaseExists: class TestRepositoryBaseExists:
"""Tests for exists method.""" """Tests for exists method."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -656,7 +656,7 @@ class TestCRUDBaseExists:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(async_test_user.id)) result = await user_repo.exists(session, id=str(async_test_user.id))
assert result is True assert result is True
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -665,7 +665,7 @@ class TestCRUDBaseExists:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(uuid4())) result = await user_repo.exists(session, id=str(uuid4()))
assert result is False assert result is False
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -674,11 +674,11 @@ class TestCRUDBaseExists:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.exists(session, id="invalid-uuid") result = await user_repo.exists(session, id="invalid-uuid")
assert result is False assert result is False
class TestCRUDBaseSoftDelete: class TestRepositoryBaseSoftDelete:
"""Tests for soft_delete method.""" """Tests for soft_delete method."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -694,13 +694,13 @@ class TestCRUDBaseSoftDelete:
first_name="Soft", first_name="Soft",
last_name="Delete", last_name="Delete",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
# Soft delete the user # Soft delete the user
async with SessionLocal() as session: async with SessionLocal() as session:
deleted = await user_crud.soft_delete(session, id=str(user_id)) deleted = await user_repo.soft_delete(session, id=str(user_id))
assert deleted is not None assert deleted is not None
assert deleted.deleted_at is not None assert deleted.deleted_at is not None
@@ -710,7 +710,7 @@ class TestCRUDBaseSoftDelete:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id="invalid-uuid") result = await user_repo.soft_delete(session, id="invalid-uuid")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -719,7 +719,7 @@ class TestCRUDBaseSoftDelete:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id=str(uuid4())) result = await user_repo.soft_delete(session, id=str(uuid4()))
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -735,18 +735,18 @@ class TestCRUDBaseSoftDelete:
first_name="Soft", first_name="Soft",
last_name="Delete2", last_name="Delete2",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
# Soft delete with UUID object # Soft delete with UUID object
async with SessionLocal() as session: async with SessionLocal() as session:
deleted = await user_crud.soft_delete(session, id=user_id) # UUID object deleted = await user_repo.soft_delete(session, id=user_id) # UUID object
assert deleted is not None assert deleted is not None
assert deleted.deleted_at is not None assert deleted.deleted_at is not None
class TestCRUDBaseRestore: class TestRepositoryBaseRestore:
"""Tests for restore method.""" """Tests for restore method."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -762,16 +762,16 @@ class TestCRUDBaseRestore:
first_name="Restore", first_name="Restore",
last_name="Test", last_name="Test",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id)) await user_repo.soft_delete(session, id=str(user_id))
# Restore the user # Restore the user
async with SessionLocal() as session: async with SessionLocal() as session:
restored = await user_crud.restore(session, id=str(user_id)) restored = await user_repo.restore(session, id=str(user_id))
assert restored is not None assert restored is not None
assert restored.deleted_at is None assert restored.deleted_at is None
@@ -781,7 +781,7 @@ class TestCRUDBaseRestore:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.restore(session, id="invalid-uuid") result = await user_repo.restore(session, id="invalid-uuid")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -790,7 +790,7 @@ class TestCRUDBaseRestore:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.restore(session, id=str(uuid4())) result = await user_repo.restore(session, id=str(uuid4()))
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -800,7 +800,7 @@ class TestCRUDBaseRestore:
async with SessionLocal() as session: async with SessionLocal() as session:
# Try to restore a user that's not deleted # Try to restore a user that's not deleted
result = await user_crud.restore(session, id=str(async_test_user.id)) result = await user_repo.restore(session, id=str(async_test_user.id))
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -816,21 +816,21 @@ class TestCRUDBaseRestore:
first_name="Restore", first_name="Restore",
last_name="Test2", last_name="Test2",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id)) await user_repo.soft_delete(session, id=str(user_id))
# Restore with UUID object # Restore with UUID object
async with SessionLocal() as session: async with SessionLocal() as session:
restored = await user_crud.restore(session, id=user_id) # UUID object restored = await user_repo.restore(session, id=user_id) # UUID object
assert restored is not None assert restored is not None
assert restored.deleted_at is None assert restored.deleted_at is None
class TestCRUDBasePaginationValidation: class TestRepositoryBasePaginationValidation:
"""Tests for pagination parameter validation (covers lines 254-260).""" """Tests for pagination parameter validation (covers lines 254-260)."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -840,7 +840,7 @@ class TestCRUDBasePaginationValidation:
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="skip must be non-negative"): with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi_with_total(session, skip=-1, limit=10) await user_repo.get_multi_with_total(session, skip=-1, limit=10)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db): async def test_get_multi_with_total_negative_limit(self, async_test_db):
@@ -849,7 +849,7 @@ class TestCRUDBasePaginationValidation:
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="limit must be non-negative"): with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi_with_total(session, skip=0, limit=-1) await user_repo.get_multi_with_total(session, skip=0, limit=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db): async def test_get_multi_with_total_limit_too_large(self, async_test_db):
@@ -858,7 +858,7 @@ class TestCRUDBasePaginationValidation:
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"): with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, skip=0, limit=1001) await user_repo.get_multi_with_total(session, skip=0, limit=1001)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_filters( async def test_get_multi_with_total_with_filters(
@@ -868,7 +868,7 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, filters={"is_active": True} session, skip=0, limit=10, filters={"is_active": True}
) )
assert isinstance(users, list) assert isinstance(users, list)
@@ -880,7 +880,7 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="desc" session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
) )
assert isinstance(users, list) assert isinstance(users, list)
@@ -891,13 +891,13 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="asc" session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
) )
assert isinstance(users, list) assert isinstance(users, list)
class TestCRUDBaseModelsWithoutSoftDelete: class TestRepositoryBaseModelsWithoutSoftDelete:
""" """
Test soft_delete and restore on models without deleted_at column. Test soft_delete and restore on models without deleted_at column.
Covers lines 342-343, 383-384 - error handling for unsupported models. Covers lines 342-343, 383-384 - error handling for unsupported models.
@@ -912,7 +912,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
# Create an organization (which doesn't have deleted_at) # Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization from app.models.organization import Organization
from app.repositories.organization import organization_repo as org_crud from app.repositories.organization import organization_repo as org_repo
async with SessionLocal() as session: async with SessionLocal() as session:
org = Organization(name="Test Org", slug="test-org") org = Organization(name="Test Org", slug="test-org")
@@ -925,7 +925,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
with pytest.raises( with pytest.raises(
InvalidInputError, match="does not have a deleted_at column" InvalidInputError, match="does not have a deleted_at column"
): ):
await org_crud.soft_delete(session, id=str(org_id)) await org_repo.soft_delete(session, id=str(org_id))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_model_without_deleted_at(self, async_test_db): async def test_restore_model_without_deleted_at(self, async_test_db):
@@ -934,7 +934,7 @@ class TestCRUDBaseModelsWithoutSoftDelete:
# Create an organization (which doesn't have deleted_at) # Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization from app.models.organization import Organization
from app.repositories.organization import organization_repo as org_crud from app.repositories.organization import organization_repo as org_repo
async with SessionLocal() as session: async with SessionLocal() as session:
org = Organization(name="Restore Test", slug="restore-test") org = Organization(name="Restore Test", slug="restore-test")
@@ -947,10 +947,10 @@ class TestCRUDBaseModelsWithoutSoftDelete:
with pytest.raises( with pytest.raises(
InvalidInputError, match="does not have a deleted_at column" InvalidInputError, match="does not have a deleted_at column"
): ):
await org_crud.restore(session, id=str(org_id)) await org_repo.restore(session, id=str(org_id))
class TestCRUDBaseEagerLoadingWithRealOptions: class TestRepositoryBaseEagerLoadingWithRealOptions:
""" """
Test eager loading with actual SQLAlchemy load options. Test eager loading with actual SQLAlchemy load options.
Covers lines 77-78, 119-120 - options loop execution. Covers lines 77-78, 119-120 - options loop execution.
@@ -967,7 +967,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Create a session for the user # Create a session for the user
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
async with SessionLocal() as session: async with SessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -985,7 +985,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Get session with eager loading of user relationship # Get session with eager loading of user relationship
async with SessionLocal() as session: async with SessionLocal() as session:
result = await session_crud.get( result = await session_repo.get(
session, session,
id=str(session_id), id=str(session_id),
options=[joinedload(UserSession.user)], # Real option, not empty list options=[joinedload(UserSession.user)], # Real option, not empty list
@@ -1006,7 +1006,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Create multiple sessions for the user # Create multiple sessions for the user
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
async with SessionLocal() as session: async with SessionLocal() as session:
for i in range(3): for i in range(3):
@@ -1024,7 +1024,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Get sessions with eager loading # Get sessions with eager loading
async with SessionLocal() as session: async with SessionLocal() as session:
results = await session_crud.get_multi( results = await session_repo.get_multi(
session, session,
skip=0, skip=0,
limit=10, limit=10,

View File

@@ -1,6 +1,6 @@
# tests/crud/test_base_db_failures.py # tests/repositories/test_base_db_failures.py
""" """
Comprehensive tests for base CRUD database failure scenarios. Comprehensive tests for base repository database failure scenarios.
Tests exception handling, rollbacks, and error messages. Tests exception handling, rollbacks, and error messages.
""" """
@@ -11,16 +11,16 @@ import pytest
from sqlalchemy.exc import DataError, OperationalError from sqlalchemy.exc import DataError, OperationalError
from app.core.repository_exceptions import IntegrityConstraintError from app.core.repository_exceptions import IntegrityConstraintError
from app.repositories.user import user_repo as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
class TestBaseCRUDCreateFailures: class TestBaseRepositoryCreateFailures:
"""Test base CRUD create method exception handling.""" """Test base repository create method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_operational_error_triggers_rollback(self, async_test_db): async def test_create_operational_error_triggers_rollback(self, async_test_db):
"""Test that OperationalError triggers rollback (User CRUD catches as Exception).""" """Test that OperationalError triggers rollback (User repository catches as Exception)."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -41,16 +41,16 @@ class TestBaseCRUDCreateFailures:
last_name="User", last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User repository catches this as generic Exception and re-raises
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
# Verify rollback was called # Verify rollback was called
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_data_error_triggers_rollback(self, async_test_db): async def test_create_data_error_triggers_rollback(self, async_test_db):
"""Test that DataError triggers rollback (User CRUD catches as Exception).""" """Test that DataError triggers rollback (User repository catches as Exception)."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -69,9 +69,9 @@ class TestBaseCRUDCreateFailures:
last_name="User", last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User repository catches this as generic Exception and re-raises
with pytest.raises(DataError): with pytest.raises(DataError):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@@ -97,13 +97,13 @@ class TestBaseCRUDCreateFailures:
) )
with pytest.raises(RuntimeError, match="Unexpected database error"): with pytest.raises(RuntimeError, match="Unexpected database error"):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDUpdateFailures: class TestBaseRepositoryUpdateFailures:
"""Test base CRUD update method exception handling.""" """Test base repository update method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user): async def test_update_operational_error(self, async_test_db, async_test_user):
@@ -111,7 +111,7 @@ class TestBaseCRUDUpdateFailures:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
async def mock_commit(): async def mock_commit():
raise OperationalError("Connection timeout", {}, Exception("Timeout")) raise OperationalError("Connection timeout", {}, Exception("Timeout"))
@@ -123,7 +123,7 @@ class TestBaseCRUDUpdateFailures:
with pytest.raises( with pytest.raises(
IntegrityConstraintError, match="Database operation failed" IntegrityConstraintError, match="Database operation failed"
): ):
await user_crud.update( await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Updated"} session, db_obj=user, obj_in={"first_name": "Updated"}
) )
@@ -135,7 +135,7 @@ class TestBaseCRUDUpdateFailures:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
async def mock_commit(): async def mock_commit():
raise DataError("Invalid data", {}, Exception("Data type mismatch")) raise DataError("Invalid data", {}, Exception("Data type mismatch"))
@@ -147,7 +147,7 @@ class TestBaseCRUDUpdateFailures:
with pytest.raises( with pytest.raises(
IntegrityConstraintError, match="Database operation failed" IntegrityConstraintError, match="Database operation failed"
): ):
await user_crud.update( await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Updated"} session, db_obj=user, obj_in={"first_name": "Updated"}
) )
@@ -159,7 +159,7 @@ class TestBaseCRUDUpdateFailures:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
async def mock_commit(): async def mock_commit():
raise KeyError("Unexpected error") raise KeyError("Unexpected error")
@@ -169,15 +169,15 @@ class TestBaseCRUDUpdateFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(KeyError): with pytest.raises(KeyError):
await user_crud.update( await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Updated"} session, db_obj=user, obj_in={"first_name": "Updated"}
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDRemoveFailures: class TestBaseRepositoryRemoveFailures:
"""Test base CRUD remove method exception handling.""" """Test base repository remove method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_unexpected_error_triggers_rollback( async def test_remove_unexpected_error_triggers_rollback(
@@ -196,12 +196,12 @@ class TestBaseCRUDRemoveFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(RuntimeError, match="Database write failed"): with pytest.raises(RuntimeError, match="Database write failed"):
await user_crud.remove(session, id=str(async_test_user.id)) await user_repo.remove(session, id=str(async_test_user.id))
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDGetMultiWithTotalFailures: class TestBaseRepositoryGetMultiWithTotalFailures:
"""Test get_multi_with_total exception handling.""" """Test get_multi_with_total exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -217,10 +217,10 @@ class TestBaseCRUDGetMultiWithTotalFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.get_multi_with_total(session, skip=0, limit=10) await user_repo.get_multi_with_total(session, skip=0, limit=10)
class TestBaseCRUDCountFailures: class TestBaseRepositoryCountFailures:
"""Test count method exception handling.""" """Test count method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -235,10 +235,10 @@ class TestBaseCRUDCountFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.count(session) await user_repo.count(session)
class TestBaseCRUDSoftDeleteFailures: class TestBaseRepositorySoftDeleteFailures:
"""Test soft_delete method exception handling.""" """Test soft_delete method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -258,12 +258,12 @@ class TestBaseCRUDSoftDeleteFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(RuntimeError, match="Soft delete failed"): with pytest.raises(RuntimeError, match="Soft delete failed"):
await user_crud.soft_delete(session, id=str(async_test_user.id)) await user_repo.soft_delete(session, id=str(async_test_user.id))
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDRestoreFailures: class TestBaseRepositoryRestoreFailures:
"""Test restore method exception handling.""" """Test restore method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -279,12 +279,12 @@ class TestBaseCRUDRestoreFailures:
first_name="Restore", first_name="Restore",
last_name="Test", last_name="Test",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id)) await user_repo.soft_delete(session, id=str(user_id))
# Now test restore failure # Now test restore failure
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -297,12 +297,12 @@ class TestBaseCRUDRestoreFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(RuntimeError, match="Restore failed"): with pytest.raises(RuntimeError, match="Restore failed"):
await user_crud.restore(session, id=str(user_id)) await user_repo.restore(session, id=str(user_id))
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDGetFailures: class TestBaseRepositoryGetFailures:
"""Test get method exception handling.""" """Test get method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -317,10 +317,10 @@ class TestBaseCRUDGetFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.get(session, id=str(uuid4())) await user_repo.get(session, id=str(uuid4()))
class TestBaseCRUDGetMultiFailures: class TestBaseRepositoryGetMultiFailures:
"""Test get_multi method exception handling.""" """Test get_multi method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -335,4 +335,4 @@ class TestBaseCRUDGetMultiFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.get_multi(session, skip=0, limit=10) await user_repo.get_multi(session, skip=0, limit=10)

View File

@@ -1,6 +1,6 @@
# tests/crud/test_oauth.py # tests/repositories/test_oauth.py
""" """
Comprehensive tests for OAuth CRUD operations. Comprehensive tests for OAuth repository operations.
""" """
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@@ -14,8 +14,8 @@ from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
class TestOAuthAccountCRUD: class TestOAuthAccountRepository:
"""Tests for OAuth account CRUD operations.""" """Tests for OAuth account repository operations."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_account(self, async_test_db, async_test_user): async def test_create_account(self, async_test_db, async_test_user):
@@ -269,8 +269,8 @@ class TestOAuthAccountCRUD:
assert updated.refresh_token == "new_refresh_token" assert updated.refresh_token == "new_refresh_token"
class TestOAuthStateCRUD: class TestOAuthStateRepository:
"""Tests for OAuth state CRUD operations.""" """Tests for OAuth state repository operations."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_state(self, async_test_db): async def test_create_state(self, async_test_db):
@@ -376,8 +376,8 @@ class TestOAuthStateCRUD:
assert result is not None assert result is not None
class TestOAuthClientCRUD: class TestOAuthClientRepository:
"""Tests for OAuth client CRUD operations (provider mode).""" """Tests for OAuth client repository operations (provider mode)."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_public_client(self, async_test_db): async def test_create_public_client(self, async_test_db):

View File

@@ -1,6 +1,6 @@
# tests/crud/test_organization_async.py # tests/repositories/test_organization_async.py
""" """
Comprehensive tests for async organization CRUD operations. Comprehensive tests for async organization repository operations.
""" """
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@@ -12,7 +12,7 @@ from sqlalchemy import select
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user_organization import OrganizationRole, UserOrganization from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.organization import organization_repo as organization_crud from app.repositories.organization import organization_repo as organization_repo
from app.schemas.organizations import OrganizationCreate from app.schemas.organizations import OrganizationCreate
@@ -35,7 +35,7 @@ class TestGetBySlug:
# Get by slug # Get by slug
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.get_by_slug(session, slug="test-org") result = await organization_repo.get_by_slug(session, slug="test-org")
assert result is not None assert result is not None
assert result.id == org_id assert result.id == org_id
assert result.slug == "test-org" assert result.slug == "test-org"
@@ -46,7 +46,7 @@ class TestGetBySlug:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.get_by_slug(session, slug="nonexistent") result = await organization_repo.get_by_slug(session, slug="nonexistent")
assert result is None assert result is None
@@ -55,7 +55,7 @@ class TestCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_success(self, async_test_db): async def test_create_success(self, async_test_db):
"""Test successfully creating an organization_crud.""" """Test successfully creating an organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -66,7 +66,7 @@ class TestCreate:
is_active=True, is_active=True,
settings={"key": "value"}, settings={"key": "value"},
) )
result = await organization_crud.create(session, obj_in=org_in) result = await organization_repo.create(session, obj_in=org_in)
assert result.name == "New Org" assert result.name == "New Org"
assert result.slug == "new-org" assert result.slug == "new-org"
@@ -89,7 +89,7 @@ class TestCreate:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug") org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug")
with pytest.raises(DuplicateEntryError, match="already exists"): with pytest.raises(DuplicateEntryError, match="already exists"):
await organization_crud.create(session, obj_in=org_in) await organization_repo.create(session, obj_in=org_in)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_without_settings(self, async_test_db): async def test_create_without_settings(self, async_test_db):
@@ -98,7 +98,7 @@ class TestCreate:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org_in = OrganizationCreate(name="No Settings Org", slug="no-settings") org_in = OrganizationCreate(name="No Settings Org", slug="no-settings")
result = await organization_crud.create(session, obj_in=org_in) result = await organization_repo.create(session, obj_in=org_in)
assert result.settings == {} assert result.settings == {}
@@ -119,7 +119,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters(session) orgs, total = await organization_repo.get_multi_with_filters(session)
assert total == 5 assert total == 5
assert len(orgs) == 5 assert len(orgs) == 5
@@ -135,7 +135,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters( orgs, total = await organization_repo.get_multi_with_filters(
session, is_active=True session, is_active=True
) )
assert total == 1 assert total == 1
@@ -157,7 +157,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters( orgs, total = await organization_repo.get_multi_with_filters(
session, search="tech" session, search="tech"
) )
assert total == 1 assert total == 1
@@ -175,7 +175,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters( orgs, total = await organization_repo.get_multi_with_filters(
session, skip=2, limit=3 session, skip=2, limit=3
) )
assert total == 10 assert total == 10
@@ -193,7 +193,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, _total = await organization_crud.get_multi_with_filters( orgs, _total = await organization_repo.get_multi_with_filters(
session, sort_by="name", sort_order="asc" session, sort_by="name", sort_order="asc"
) )
assert orgs[0].name == "A Org" assert orgs[0].name == "A Org"
@@ -205,7 +205,7 @@ class TestGetMemberCount:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_member_count_success(self, async_test_db, async_test_user): async def test_get_member_count_success(self, async_test_db, async_test_user):
"""Test getting member count for organization_crud.""" """Test getting member count for organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -225,7 +225,7 @@ class TestGetMemberCount:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await organization_crud.get_member_count( count = await organization_repo.get_member_count(
session, organization_id=org_id session, organization_id=org_id
) )
assert count == 1 assert count == 1
@@ -242,7 +242,7 @@ class TestGetMemberCount:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await organization_crud.get_member_count( count = await organization_repo.get_member_count(
session, organization_id=org_id session, organization_id=org_id
) )
assert count == 0 assert count == 0
@@ -253,7 +253,7 @@ class TestAddUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_user_success(self, async_test_db, async_test_user): async def test_add_user_success(self, async_test_db, async_test_user):
"""Test successfully adding a user to organization_crud.""" """Test successfully adding a user to organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -263,7 +263,7 @@ class TestAddUser:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.add_user( result = await organization_repo.add_user(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -297,7 +297,7 @@ class TestAddUser:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(DuplicateEntryError, match="already a member"): with pytest.raises(DuplicateEntryError, match="already a member"):
await organization_crud.add_user( await organization_repo.add_user(
session, organization_id=org_id, user_id=async_test_user.id session, organization_id=org_id, user_id=async_test_user.id
) )
@@ -322,7 +322,7 @@ class TestAddUser:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.add_user( result = await organization_repo.add_user(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -338,7 +338,7 @@ class TestRemoveUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_user_success(self, async_test_db, async_test_user): async def test_remove_user_success(self, async_test_db, async_test_user):
"""Test successfully removing a user from organization_crud.""" """Test successfully removing a user from organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -357,7 +357,7 @@ class TestRemoveUser:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.remove_user( result = await organization_repo.remove_user(
session, organization_id=org_id, user_id=async_test_user.id session, organization_id=org_id, user_id=async_test_user.id
) )
@@ -385,7 +385,7 @@ class TestRemoveUser:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.remove_user( result = await organization_repo.remove_user(
session, organization_id=org_id, user_id=uuid4() session, organization_id=org_id, user_id=uuid4()
) )
@@ -416,7 +416,7 @@ class TestUpdateUserRole:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.update_user_role( result = await organization_repo.update_user_role(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -439,7 +439,7 @@ class TestUpdateUserRole:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.update_user_role( result = await organization_repo.update_user_role(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=uuid4(), user_id=uuid4(),
@@ -475,7 +475,7 @@ class TestGetOrganizationMembers:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
members, total = await organization_crud.get_organization_members( members, total = await organization_repo.get_organization_members(
session, organization_id=org_id session, organization_id=org_id
) )
@@ -508,7 +508,7 @@ class TestGetOrganizationMembers:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
members, total = await organization_crud.get_organization_members( members, total = await organization_repo.get_organization_members(
session, organization_id=org_id, skip=0, limit=10 session, organization_id=org_id, skip=0, limit=10
) )
@@ -539,7 +539,7 @@ class TestGetUserOrganizations:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs = await organization_crud.get_user_organizations( orgs = await organization_repo.get_user_organizations(
session, user_id=async_test_user.id session, user_id=async_test_user.id
) )
@@ -575,7 +575,7 @@ class TestGetUserOrganizations:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs = await organization_crud.get_user_organizations( orgs = await organization_repo.get_user_organizations(
session, user_id=async_test_user.id, is_active=True session, user_id=async_test_user.id, is_active=True
) )
@@ -588,7 +588,7 @@ class TestGetUserRole:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_role_in_org_success(self, async_test_db, async_test_user): async def test_get_user_role_in_org_success(self, async_test_db, async_test_user):
"""Test getting user role in organization_crud.""" """Test getting user role in organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -607,7 +607,7 @@ class TestGetUserRole:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
role = await organization_crud.get_user_role_in_org( role = await organization_repo.get_user_role_in_org(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -625,7 +625,7 @@ class TestGetUserRole:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
role = await organization_crud.get_user_role_in_org( role = await organization_repo.get_user_role_in_org(
session, user_id=uuid4(), organization_id=org_id session, user_id=uuid4(), organization_id=org_id
) )
@@ -656,7 +656,7 @@ class TestIsUserOrgOwner:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_owner = await organization_crud.is_user_org_owner( is_owner = await organization_repo.is_user_org_owner(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -683,7 +683,7 @@ class TestIsUserOrgOwner:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_owner = await organization_crud.is_user_org_owner( is_owner = await organization_repo.is_user_org_owner(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -720,7 +720,7 @@ class TestGetMultiWithMemberCounts:
( (
orgs_with_counts, orgs_with_counts,
total, total,
) = await organization_crud.get_multi_with_member_counts(session) ) = await organization_repo.get_multi_with_member_counts(session)
assert total == 2 assert total == 2
assert len(orgs_with_counts) == 2 assert len(orgs_with_counts) == 2
@@ -745,7 +745,7 @@ class TestGetMultiWithMemberCounts:
( (
orgs_with_counts, orgs_with_counts,
total, total,
) = await organization_crud.get_multi_with_member_counts( ) = await organization_repo.get_multi_with_member_counts(
session, is_active=True session, is_active=True
) )
@@ -767,7 +767,7 @@ class TestGetMultiWithMemberCounts:
( (
orgs_with_counts, orgs_with_counts,
total, total,
) = await organization_crud.get_multi_with_member_counts( ) = await organization_repo.get_multi_with_member_counts(
session, search="tech" session, search="tech"
) )
@@ -801,7 +801,7 @@ class TestGetUserOrganizationsWithDetails:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs_with_details = ( orgs_with_details = (
await organization_crud.get_user_organizations_with_details( await organization_repo.get_user_organizations_with_details(
session, user_id=async_test_user.id session, user_id=async_test_user.id
) )
) )
@@ -841,7 +841,7 @@ class TestGetUserOrganizationsWithDetails:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs_with_details = ( orgs_with_details = (
await organization_crud.get_user_organizations_with_details( await organization_repo.get_user_organizations_with_details(
session, user_id=async_test_user.id, is_active=True session, user_id=async_test_user.id, is_active=True
) )
) )
@@ -874,7 +874,7 @@ class TestIsUserOrgAdmin:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_admin = await organization_crud.is_user_org_admin( is_admin = await organization_repo.is_user_org_admin(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -901,7 +901,7 @@ class TestIsUserOrgAdmin:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_admin = await organization_crud.is_user_org_admin( is_admin = await organization_repo.is_user_org_admin(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -928,7 +928,7 @@ class TestIsUserOrgAdmin:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_admin = await organization_crud.is_user_org_admin( is_admin = await organization_repo.is_user_org_admin(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -937,7 +937,7 @@ class TestIsUserOrgAdmin:
class TestOrganizationExceptionHandlers: class TestOrganizationExceptionHandlers:
""" """
Test exception handlers in organization CRUD methods. Test exception handlers in organization repository methods.
Uses mocks to trigger database errors and verify proper error handling. Uses mocks to trigger database errors and verify proper error handling.
Covers lines: 33-35, 57-62, 114-116, 130-132, 207-209, 258-260, 291-294, 326-329, 385-387, 409-411, 466-468, 491-493 Covers lines: 33-35, 57-62, 114-116, 130-132, 207-209, 258-260, 291-294, 326-329, 385-387, 409-411, 466-468, 491-493
""" """
@@ -952,7 +952,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Database connection lost") session, "execute", side_effect=Exception("Database connection lost")
): ):
with pytest.raises(Exception, match="Database connection lost"): with pytest.raises(Exception, match="Database connection lost"):
await organization_crud.get_by_slug(session, slug="test-slug") await organization_repo.get_by_slug(session, slug="test-slug")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_integrity_error_non_slug(self, async_test_db): async def test_create_integrity_error_non_slug(self, async_test_db):
@@ -976,7 +976,7 @@ class TestOrganizationExceptionHandlers:
with pytest.raises( with pytest.raises(
IntegrityConstraintError, match="Database integrity error" IntegrityConstraintError, match="Database integrity error"
): ):
await organization_crud.create(session, obj_in=org_in) await organization_repo.create(session, obj_in=org_in)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db): async def test_create_unexpected_error(self, async_test_db):
@@ -990,7 +990,7 @@ class TestOrganizationExceptionHandlers:
with patch.object(session, "rollback", new_callable=AsyncMock): with patch.object(session, "rollback", new_callable=AsyncMock):
org_in = OrganizationCreate(name="Test", slug="test") org_in = OrganizationCreate(name="Test", slug="test")
with pytest.raises(RuntimeError, match="Unexpected error"): with pytest.raises(RuntimeError, match="Unexpected error"):
await organization_crud.create(session, obj_in=org_in) await organization_repo.create(session, obj_in=org_in)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_filters_database_error(self, async_test_db): async def test_get_multi_with_filters_database_error(self, async_test_db):
@@ -1002,7 +1002,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Query timeout") session, "execute", side_effect=Exception("Query timeout")
): ):
with pytest.raises(Exception, match="Query timeout"): with pytest.raises(Exception, match="Query timeout"):
await organization_crud.get_multi_with_filters(session) await organization_repo.get_multi_with_filters(session)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_member_count_database_error(self, async_test_db): async def test_get_member_count_database_error(self, async_test_db):
@@ -1016,7 +1016,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Count query failed") session, "execute", side_effect=Exception("Count query failed")
): ):
with pytest.raises(Exception, match="Count query failed"): with pytest.raises(Exception, match="Count query failed"):
await organization_crud.get_member_count( await organization_repo.get_member_count(
session, organization_id=uuid4() session, organization_id=uuid4()
) )
@@ -1030,7 +1030,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Complex query failed") session, "execute", side_effect=Exception("Complex query failed")
): ):
with pytest.raises(Exception, match="Complex query failed"): with pytest.raises(Exception, match="Complex query failed"):
await organization_crud.get_multi_with_member_counts(session) await organization_repo.get_multi_with_member_counts(session)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_user_integrity_error(self, async_test_db, async_test_user): async def test_add_user_integrity_error(self, async_test_db, async_test_user):
@@ -1064,7 +1064,7 @@ class TestOrganizationExceptionHandlers:
IntegrityConstraintError, IntegrityConstraintError,
match="Failed to add user to organization", match="Failed to add user to organization",
): ):
await organization_crud.add_user( await organization_repo.add_user(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -1082,7 +1082,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Delete failed") session, "execute", side_effect=Exception("Delete failed")
): ):
with pytest.raises(Exception, match="Delete failed"): with pytest.raises(Exception, match="Delete failed"):
await organization_crud.remove_user( await organization_repo.remove_user(
session, organization_id=uuid4(), user_id=async_test_user.id session, organization_id=uuid4(), user_id=async_test_user.id
) )
@@ -1100,7 +1100,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Update failed") session, "execute", side_effect=Exception("Update failed")
): ):
with pytest.raises(Exception, match="Update failed"): with pytest.raises(Exception, match="Update failed"):
await organization_crud.update_user_role( await organization_repo.update_user_role(
session, session,
organization_id=uuid4(), organization_id=uuid4(),
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -1119,7 +1119,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Members query failed") session, "execute", side_effect=Exception("Members query failed")
): ):
with pytest.raises(Exception, match="Members query failed"): with pytest.raises(Exception, match="Members query failed"):
await organization_crud.get_organization_members( await organization_repo.get_organization_members(
session, organization_id=uuid4() session, organization_id=uuid4()
) )
@@ -1135,7 +1135,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("User orgs query failed") session, "execute", side_effect=Exception("User orgs query failed")
): ):
with pytest.raises(Exception, match="User orgs query failed"): with pytest.raises(Exception, match="User orgs query failed"):
await organization_crud.get_user_organizations( await organization_repo.get_user_organizations(
session, user_id=async_test_user.id session, user_id=async_test_user.id
) )
@@ -1151,7 +1151,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Details query failed") session, "execute", side_effect=Exception("Details query failed")
): ):
with pytest.raises(Exception, match="Details query failed"): with pytest.raises(Exception, match="Details query failed"):
await organization_crud.get_user_organizations_with_details( await organization_repo.get_user_organizations_with_details(
session, user_id=async_test_user.id session, user_id=async_test_user.id
) )
@@ -1169,6 +1169,6 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Role query failed") session, "execute", side_effect=Exception("Role query failed")
): ):
with pytest.raises(Exception, match="Role query failed"): with pytest.raises(Exception, match="Role query failed"):
await organization_crud.get_user_role_in_org( await organization_repo.get_user_role_in_org(
session, user_id=async_test_user.id, organization_id=uuid4() session, user_id=async_test_user.id, organization_id=uuid4()
) )

View File

@@ -1,6 +1,6 @@
# tests/crud/test_session_async.py # tests/repositories/test_session_async.py
""" """
Comprehensive tests for async session CRUD operations. Comprehensive tests for async session repository operations.
""" """
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@@ -10,7 +10,7 @@ import pytest
from app.core.repository_exceptions import InvalidInputError from app.core.repository_exceptions import InvalidInputError
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
@@ -37,7 +37,7 @@ class TestGetByJti:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="test_jti_123") result = await session_repo.get_by_jti(session, jti="test_jti_123")
assert result is not None assert result is not None
assert result.refresh_token_jti == "test_jti_123" assert result.refresh_token_jti == "test_jti_123"
@@ -47,7 +47,7 @@ class TestGetByJti:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="nonexistent") result = await session_repo.get_by_jti(session, jti="nonexistent")
assert result is None assert result is None
@@ -74,7 +74,7 @@ class TestGetActiveByJti:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_active_by_jti(session, jti="active_jti") result = await session_repo.get_active_by_jti(session, jti="active_jti")
assert result is not None assert result is not None
assert result.is_active is True assert result.is_active is True
@@ -98,7 +98,7 @@ class TestGetActiveByJti:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_active_by_jti(session, jti="inactive_jti") result = await session_repo.get_active_by_jti(session, jti="inactive_jti")
assert result is None assert result is None
@@ -135,7 +135,7 @@ class TestGetUserSessions:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions( results = await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=True session, user_id=str(async_test_user.id), active_only=True
) )
assert len(results) == 1 assert len(results) == 1
@@ -162,7 +162,7 @@ class TestGetUserSessions:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions( results = await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=False session, user_id=str(async_test_user.id), active_only=False
) )
assert len(results) == 3 assert len(results) == 3
@@ -173,7 +173,7 @@ class TestCreateSession:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_session_success(self, async_test_db, async_test_user): async def test_create_session_success(self, async_test_db, async_test_user):
"""Test successfully creating a session_crud.""" """Test successfully creating a session_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -189,7 +189,7 @@ class TestCreateSession:
location_city="San Francisco", location_city="San Francisco",
location_country="USA", location_country="USA",
) )
result = await session_crud.create_session(session, obj_in=session_data) result = await session_repo.create_session(session, obj_in=session_data)
assert result.user_id == async_test_user.id assert result.user_id == async_test_user.id
assert result.refresh_token_jti == "new_jti" assert result.refresh_token_jti == "new_jti"
@@ -202,7 +202,7 @@ class TestDeactivate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deactivate_success(self, async_test_db, async_test_user): async def test_deactivate_success(self, async_test_db, async_test_user):
"""Test successfully deactivating a session_crud.""" """Test successfully deactivating a session_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -221,7 +221,7 @@ class TestDeactivate:
session_id = user_session.id session_id = user_session.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(session_id)) result = await session_repo.deactivate(session, session_id=str(session_id))
assert result is not None assert result is not None
assert result.is_active is False assert result.is_active is False
@@ -231,7 +231,7 @@ class TestDeactivate:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(uuid4())) result = await session_repo.deactivate(session, session_id=str(uuid4()))
assert result is None assert result is None
@@ -262,7 +262,7 @@ class TestDeactivateAllUserSessions:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.deactivate_all_user_sessions( count = await session_repo.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
assert count == 2 assert count == 2
@@ -292,7 +292,7 @@ class TestUpdateLastUsed:
await session.refresh(user_session) await session.refresh(user_session)
old_time = user_session.last_used_at old_time = user_session.last_used_at
result = await session_crud.update_last_used(session, session=user_session) result = await session_repo.update_last_used(session, session=user_session)
assert result.last_used_at > old_time assert result.last_used_at > old_time
@@ -321,7 +321,7 @@ class TestGetUserSessionCount:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count( count = await session_repo.get_user_session_count(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
assert count == 3 assert count == 3
@@ -332,7 +332,7 @@ class TestGetUserSessionCount:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count( count = await session_repo.get_user_session_count(
session, user_id=str(uuid4()) session, user_id=str(uuid4())
) )
assert count == 0 assert count == 0
@@ -364,7 +364,7 @@ class TestUpdateRefreshToken:
new_jti = "new_jti_123" new_jti = "new_jti_123"
new_expires = datetime.now(UTC) + timedelta(days=14) new_expires = datetime.now(UTC) + timedelta(days=14)
result = await session_crud.update_refresh_token( result = await session_repo.update_refresh_token(
session, session,
session=user_session, session=user_session,
new_jti=new_jti, new_jti=new_jti,
@@ -410,7 +410,7 @@ class TestCleanupExpired:
# Cleanup # Cleanup
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30) count = await session_repo.cleanup_expired(session, keep_days=30)
assert count == 1 assert count == 1
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -436,7 +436,7 @@ class TestCleanupExpired:
# Cleanup # Cleanup
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30) count = await session_repo.cleanup_expired(session, keep_days=30)
assert count == 0 # Should not delete recent sessions assert count == 0 # Should not delete recent sessions
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -462,7 +462,7 @@ class TestCleanupExpired:
# Cleanup # Cleanup
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30) count = await session_repo.cleanup_expired(session, keep_days=30)
assert count == 0 # Should not delete active sessions assert count == 0 # Should not delete active sessions
@@ -493,7 +493,7 @@ class TestCleanupExpiredForUser:
# Cleanup for user # Cleanup for user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user( count = await session_repo.cleanup_expired_for_user(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
assert count == 1 assert count == 1
@@ -505,7 +505,7 @@ class TestCleanupExpiredForUser:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError, match="Invalid user ID format"): with pytest.raises(InvalidInputError, match="Invalid user ID format"):
await session_crud.cleanup_expired_for_user( await session_repo.cleanup_expired_for_user(
session, user_id="not-a-valid-uuid" session, user_id="not-a-valid-uuid"
) )
@@ -533,7 +533,7 @@ class TestCleanupExpiredForUser:
# Cleanup # Cleanup
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user( count = await session_repo.cleanup_expired_for_user(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
assert count == 0 # Should not delete active sessions assert count == 0 # Should not delete active sessions
@@ -565,7 +565,7 @@ class TestGetUserSessionsWithUser:
# Get with user relationship # Get with user relationship
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions( results = await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id), with_user=True session, user_id=str(async_test_user.id), with_user=True
) )
assert len(results) >= 1 assert len(results) >= 1

View File

@@ -1,6 +1,6 @@
# tests/crud/test_session_db_failures.py # tests/repositories/test_session_db_failures.py
""" """
Comprehensive tests for session CRUD database failure scenarios. Comprehensive tests for session repository database failure scenarios.
""" """
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@@ -12,11 +12,11 @@ from sqlalchemy.exc import OperationalError
from app.core.repository_exceptions import IntegrityConstraintError from app.core.repository_exceptions import IntegrityConstraintError
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
class TestSessionCRUDGetByJtiFailures: class TestSessionRepositoryGetByJtiFailures:
"""Test get_by_jti exception handling.""" """Test get_by_jti exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -31,10 +31,10 @@ class TestSessionCRUDGetByJtiFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_by_jti(session, jti="test_jti") await session_repo.get_by_jti(session, jti="test_jti")
class TestSessionCRUDGetActiveByJtiFailures: class TestSessionRepositoryGetActiveByJtiFailures:
"""Test get_active_by_jti exception handling.""" """Test get_active_by_jti exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -49,10 +49,10 @@ class TestSessionCRUDGetActiveByJtiFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_active_by_jti(session, jti="test_jti") await session_repo.get_active_by_jti(session, jti="test_jti")
class TestSessionCRUDGetUserSessionsFailures: class TestSessionRepositoryGetUserSessionsFailures:
"""Test get_user_sessions exception handling.""" """Test get_user_sessions exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -69,12 +69,12 @@ class TestSessionCRUDGetUserSessionsFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_user_sessions( await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
class TestSessionCRUDCreateSessionFailures: class TestSessionRepositoryCreateSessionFailures:
"""Test create_session exception handling.""" """Test create_session exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -106,7 +106,7 @@ class TestSessionCRUDCreateSessionFailures:
with pytest.raises( with pytest.raises(
IntegrityConstraintError, match="Failed to create session" IntegrityConstraintError, match="Failed to create session"
): ):
await session_crud.create_session(session, obj_in=session_data) await session_repo.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@@ -139,12 +139,12 @@ class TestSessionCRUDCreateSessionFailures:
with pytest.raises( with pytest.raises(
IntegrityConstraintError, match="Failed to create session" IntegrityConstraintError, match="Failed to create session"
): ):
await session_crud.create_session(session, obj_in=session_data) await session_repo.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDDeactivateFailures: class TestSessionRepositoryDeactivateFailures:
"""Test deactivate exception handling.""" """Test deactivate exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -182,14 +182,14 @@ class TestSessionCRUDDeactivateFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.deactivate( await session_repo.deactivate(
session, session_id=str(session_id) session, session_id=str(session_id)
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDDeactivateAllFailures: class TestSessionRepositoryDeactivateAllFailures:
"""Test deactivate_all_user_sessions exception handling.""" """Test deactivate_all_user_sessions exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -209,14 +209,14 @@ class TestSessionCRUDDeactivateAllFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.deactivate_all_user_sessions( await session_repo.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDUpdateLastUsedFailures: class TestSessionRepositoryUpdateLastUsedFailures:
"""Test update_last_used exception handling.""" """Test update_last_used exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -259,12 +259,12 @@ class TestSessionCRUDUpdateLastUsedFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.update_last_used(session, session=sess) await session_repo.update_last_used(session, session=sess)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDUpdateRefreshTokenFailures: class TestSessionRepositoryUpdateRefreshTokenFailures:
"""Test update_refresh_token exception handling.""" """Test update_refresh_token exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -307,7 +307,7 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.update_refresh_token( await session_repo.update_refresh_token(
session, session,
session=sess, session=sess,
new_jti=str(uuid4()), new_jti=str(uuid4()),
@@ -317,7 +317,7 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDCleanupExpiredFailures: class TestSessionRepositoryCleanupExpiredFailures:
"""Test cleanup_expired exception handling.""" """Test cleanup_expired exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -337,12 +337,12 @@ class TestSessionCRUDCleanupExpiredFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.cleanup_expired(session, keep_days=30) await session_repo.cleanup_expired(session, keep_days=30)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDCleanupExpiredForUserFailures: class TestSessionRepositoryCleanupExpiredForUserFailures:
"""Test cleanup_expired_for_user exception handling.""" """Test cleanup_expired_for_user exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -362,14 +362,14 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.cleanup_expired_for_user( await session_repo.cleanup_expired_for_user(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDGetUserSessionCountFailures: class TestSessionRepositoryGetUserSessionCountFailures:
"""Test get_user_session_count exception handling.""" """Test get_user_session_count exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -386,6 +386,6 @@ class TestSessionCRUDGetUserSessionCountFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_user_session_count( await session_repo.get_user_session_count(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )

View File

@@ -1,12 +1,12 @@
# tests/crud/test_user_async.py # tests/repositories/test_user_async.py
""" """
Comprehensive tests for async user CRUD operations. Comprehensive tests for async user repository operations.
""" """
import pytest import pytest
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
from app.repositories.user import user_repo as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
@@ -19,7 +19,7 @@ class TestGetByEmail:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email=async_test_user.email) result = await user_repo.get_by_email(session, email=async_test_user.email)
assert result is not None assert result is not None
assert result.email == async_test_user.email assert result.email == async_test_user.email
assert result.id == async_test_user.id assert result.id == async_test_user.id
@@ -30,7 +30,7 @@ class TestGetByEmail:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email( result = await user_repo.get_by_email(
session, email="nonexistent@example.com" session, email="nonexistent@example.com"
) )
assert result is None assert result is None
@@ -41,7 +41,7 @@ class TestCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_user_success(self, async_test_db): async def test_create_user_success(self, async_test_db):
"""Test successfully creating a user_crud.""" """Test successfully creating a user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -52,7 +52,7 @@ class TestCreate:
last_name="User", last_name="User",
phone_number="+1234567890", phone_number="+1234567890",
) )
result = await user_crud.create(session, obj_in=user_data) result = await user_repo.create(session, obj_in=user_data)
assert result.email == "newuser@example.com" assert result.email == "newuser@example.com"
assert result.first_name == "New" assert result.first_name == "New"
@@ -76,7 +76,7 @@ class TestCreate:
last_name="User", last_name="User",
is_superuser=True, is_superuser=True,
) )
result = await user_crud.create(session, obj_in=user_data) result = await user_repo.create(session, obj_in=user_data)
assert result.is_superuser is True assert result.is_superuser is True
assert result.email == "superuser@example.com" assert result.email == "superuser@example.com"
@@ -95,7 +95,7 @@ class TestCreate:
) )
with pytest.raises(DuplicateEntryError) as exc_info: with pytest.raises(DuplicateEntryError) as exc_info:
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
assert "already exists" in str(exc_info.value).lower() assert "already exists" in str(exc_info.value).lower()
@@ -110,12 +110,12 @@ class TestUpdate:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get fresh copy of user # Get fresh copy of user
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
update_data = UserUpdate( update_data = UserUpdate(
first_name="Updated", last_name="Name", phone_number="+9876543210" first_name="Updated", last_name="Name", phone_number="+9876543210"
) )
result = await user_crud.update(session, db_obj=user, obj_in=update_data) result = await user_repo.update(session, db_obj=user, obj_in=update_data)
assert result.first_name == "Updated" assert result.first_name == "Updated"
assert result.last_name == "Name" assert result.last_name == "Name"
@@ -134,16 +134,16 @@ class TestUpdate:
first_name="Pass", first_name="Pass",
last_name="Test", last_name="Test",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
old_password_hash = user.password_hash old_password_hash = user.password_hash
# Update the password # Update the password
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(user_id)) user = await user_repo.get(session, id=str(user_id))
update_data = UserUpdate(password="NewDifferentPassword123!") update_data = UserUpdate(password="NewDifferentPassword123!")
result = await user_crud.update(session, db_obj=user, obj_in=update_data) result = await user_repo.update(session, db_obj=user, obj_in=update_data)
await session.refresh(result) await session.refresh(result)
assert result.password_hash != old_password_hash assert result.password_hash != old_password_hash
@@ -158,10 +158,10 @@ class TestUpdate:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
update_dict = {"first_name": "DictUpdate"} update_dict = {"first_name": "DictUpdate"}
result = await user_crud.update(session, db_obj=user, obj_in=update_dict) result = await user_repo.update(session, db_obj=user, obj_in=update_dict)
assert result.first_name == "DictUpdate" assert result.first_name == "DictUpdate"
@@ -175,7 +175,7 @@ class TestGetMultiWithTotal:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, total = await user_repo.get_multi_with_total(
session, skip=0, limit=10 session, skip=0, limit=10
) )
assert total >= 1 assert total >= 1
@@ -196,10 +196,10 @@ class TestGetMultiWithTotal:
first_name=f"User{i}", first_name=f"User{i}",
last_name="Test", last_name="Test",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="asc" session, skip=0, limit=10, sort_by="email", sort_order="asc"
) )
@@ -222,10 +222,10 @@ class TestGetMultiWithTotal:
first_name=f"User{i}", first_name=f"User{i}",
last_name="Test", last_name="Test",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="desc" session, skip=0, limit=10, sort_by="email", sort_order="desc"
) )
@@ -247,7 +247,7 @@ class TestGetMultiWithTotal:
first_name="Active", first_name="Active",
last_name="User", last_name="User",
) )
await user_crud.create(session, obj_in=active_user) await user_repo.create(session, obj_in=active_user)
inactive_user = UserCreate( inactive_user = UserCreate(
email="inactive@example.com", email="inactive@example.com",
@@ -255,15 +255,15 @@ class TestGetMultiWithTotal:
first_name="Inactive", first_name="Inactive",
last_name="User", last_name="User",
) )
created_inactive = await user_crud.create(session, obj_in=inactive_user) created_inactive = await user_repo.create(session, obj_in=inactive_user)
# Deactivate the user # Deactivate the user
await user_crud.update( await user_repo.update(
session, db_obj=created_inactive, obj_in={"is_active": False} session, db_obj=created_inactive, obj_in={"is_active": False}
) )
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=100, filters={"is_active": True} session, skip=0, limit=100, filters={"is_active": True}
) )
@@ -283,10 +283,10 @@ class TestGetMultiWithTotal:
first_name="Searchable", first_name="Searchable",
last_name="UserName", last_name="UserName",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, total = await user_repo.get_multi_with_total(
session, skip=0, limit=100, search="Searchable" session, skip=0, limit=100, search="Searchable"
) )
@@ -307,16 +307,16 @@ class TestGetMultiWithTotal:
first_name=f"Page{i}", first_name=f"Page{i}",
last_name="User", last_name="User",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get first page # Get first page
users_page1, total = await user_crud.get_multi_with_total( users_page1, total = await user_repo.get_multi_with_total(
session, skip=0, limit=2 session, skip=0, limit=2
) )
# Get second page # Get second page
users_page2, total2 = await user_crud.get_multi_with_total( users_page2, total2 = await user_repo.get_multi_with_total(
session, skip=2, limit=2 session, skip=2, limit=2
) )
@@ -332,7 +332,7 @@ class TestGetMultiWithTotal:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError) as exc_info: with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=-1, limit=10) await user_repo.get_multi_with_total(session, skip=-1, limit=10)
assert "skip must be non-negative" in str(exc_info.value) assert "skip must be non-negative" in str(exc_info.value)
@@ -343,7 +343,7 @@ class TestGetMultiWithTotal:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError) as exc_info: with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=-1) await user_repo.get_multi_with_total(session, skip=0, limit=-1)
assert "limit must be non-negative" in str(exc_info.value) assert "limit must be non-negative" in str(exc_info.value)
@@ -354,7 +354,7 @@ class TestGetMultiWithTotal:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(InvalidInputError) as exc_info: with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=1001) await user_repo.get_multi_with_total(session, skip=0, limit=1001)
assert "Maximum limit is 1000" in str(exc_info.value) assert "Maximum limit is 1000" in str(exc_info.value)
@@ -377,12 +377,12 @@ class TestBulkUpdateStatus:
first_name=f"Bulk{i}", first_name=f"Bulk{i}",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_ids.append(user.id) user_ids.append(user.id)
# Bulk deactivate # Bulk deactivate
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status( count = await user_repo.bulk_update_status(
session, user_ids=user_ids, is_active=False session, user_ids=user_ids, is_active=False
) )
assert count == 3 assert count == 3
@@ -390,7 +390,7 @@ class TestBulkUpdateStatus:
# Verify all are inactive # Verify all are inactive
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
for user_id in user_ids: for user_id in user_ids:
user = await user_crud.get(session, id=str(user_id)) user = await user_repo.get(session, id=str(user_id))
assert user.is_active is False assert user.is_active is False
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -399,7 +399,7 @@ class TestBulkUpdateStatus:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status( count = await user_repo.bulk_update_status(
session, user_ids=[], is_active=False session, user_ids=[], is_active=False
) )
assert count == 0 assert count == 0
@@ -417,21 +417,21 @@ class TestBulkUpdateStatus:
first_name="Reactivate", first_name="Reactivate",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
# Deactivate # Deactivate
await user_crud.update(session, db_obj=user, obj_in={"is_active": False}) await user_repo.update(session, db_obj=user, obj_in={"is_active": False})
user_id = user.id user_id = user.id
# Reactivate # Reactivate
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status( count = await user_repo.bulk_update_status(
session, user_ids=[user_id], is_active=True session, user_ids=[user_id], is_active=True
) )
assert count == 1 assert count == 1
# Verify active # Verify active
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(user_id)) user = await user_repo.get(session, id=str(user_id))
assert user.is_active is True assert user.is_active is True
@@ -453,24 +453,24 @@ class TestBulkSoftDelete:
first_name=f"Delete{i}", first_name=f"Delete{i}",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_ids.append(user.id) user_ids.append(user.id)
# Bulk delete # Bulk delete
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids) count = await user_repo.bulk_soft_delete(session, user_ids=user_ids)
assert count == 3 assert count == 3
# Verify all are soft deleted # Verify all are soft deleted
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
for user_id in user_ids: for user_id in user_ids:
user = await user_crud.get(session, id=str(user_id)) user = await user_repo.get(session, id=str(user_id))
assert user.deleted_at is not None assert user.deleted_at is not None
assert user.is_active is False assert user.is_active is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_soft_delete_with_exclusion(self, async_test_db): async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
"""Test bulk soft delete with excluded user_crud.""" """Test bulk soft delete with excluded user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users # Create multiple users
@@ -483,20 +483,20 @@ class TestBulkSoftDelete:
first_name=f"Exclude{i}", first_name=f"Exclude{i}",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_ids.append(user.id) user_ids.append(user.id)
# Bulk delete, excluding first user # Bulk delete, excluding first user
exclude_id = user_ids[0] exclude_id = user_ids[0]
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete( count = await user_repo.bulk_soft_delete(
session, user_ids=user_ids, exclude_user_id=exclude_id session, user_ids=user_ids, exclude_user_id=exclude_id
) )
assert count == 2 # Only 2 deleted assert count == 2 # Only 2 deleted
# Verify excluded user is NOT deleted # Verify excluded user is NOT deleted
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
excluded_user = await user_crud.get(session, id=str(exclude_id)) excluded_user = await user_repo.get(session, id=str(exclude_id))
assert excluded_user.deleted_at is None assert excluded_user.deleted_at is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -505,7 +505,7 @@ class TestBulkSoftDelete:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=[]) count = await user_repo.bulk_soft_delete(session, user_ids=[])
assert count == 0 assert count == 0
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -521,12 +521,12 @@ class TestBulkSoftDelete:
first_name="Only", first_name="Only",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
# Try to delete but exclude # Try to delete but exclude
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete( count = await user_repo.bulk_soft_delete(
session, user_ids=[user_id], exclude_user_id=user_id session, user_ids=[user_id], exclude_user_id=user_id
) )
assert count == 0 assert count == 0
@@ -544,15 +544,15 @@ class TestBulkSoftDelete:
first_name="PreDeleted", first_name="PreDeleted",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
# First deletion # First deletion
await user_crud.bulk_soft_delete(session, user_ids=[user_id]) await user_repo.bulk_soft_delete(session, user_ids=[user_id])
# Try to delete again # Try to delete again
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id]) count = await user_repo.bulk_soft_delete(session, user_ids=[user_id])
assert count == 0 # Already deleted assert count == 0 # Already deleted
@@ -561,16 +561,16 @@ class TestUtilityMethods:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_active_true(self, async_test_db, async_test_user): async def test_is_active_true(self, async_test_db, async_test_user):
"""Test is_active returns True for active user_crud.""" """Test is_active returns True for active user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
assert user_crud.is_active(user) is True assert user_repo.is_active(user) is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_active_false(self, async_test_db): async def test_is_active_false(self, async_test_db):
"""Test is_active returns False for inactive user_crud.""" """Test is_active returns False for inactive user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -580,10 +580,10 @@ class TestUtilityMethods:
first_name="Inactive", first_name="Inactive",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
await user_crud.update(session, db_obj=user, obj_in={"is_active": False}) await user_repo.update(session, db_obj=user, obj_in={"is_active": False})
assert user_crud.is_active(user) is False assert user_repo.is_active(user) is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_superuser_true(self, async_test_db, async_test_superuser): async def test_is_superuser_true(self, async_test_db, async_test_superuser):
@@ -591,22 +591,22 @@ class TestUtilityMethods:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_superuser.id)) user = await user_repo.get(session, id=str(async_test_superuser.id))
assert user_crud.is_superuser(user) is True assert user_repo.is_superuser(user) is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_superuser_false(self, async_test_db, async_test_user): async def test_is_superuser_false(self, async_test_db, async_test_user):
"""Test is_superuser returns False for regular user_crud.""" """Test is_superuser returns False for regular user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
assert user_crud.is_superuser(user) is False assert user_repo.is_superuser(user) is False
class TestUserExceptionHandlers: class TestUserExceptionHandlers:
""" """
Test exception handlers in user CRUD methods. Test exception handlers in user repository methods.
Covers lines: 30-32, 205-208, 257-260 Covers lines: 30-32, 205-208, 257-260
""" """
@@ -622,7 +622,7 @@ class TestUserExceptionHandlers:
session, "execute", side_effect=Exception("Database query failed") session, "execute", side_effect=Exception("Database query failed")
): ):
with pytest.raises(Exception, match="Database query failed"): with pytest.raises(Exception, match="Database query failed"):
await user_crud.get_by_email(session, email="test@example.com") await user_repo.get_by_email(session, email="test@example.com")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_update_status_database_error( async def test_bulk_update_status_database_error(
@@ -640,7 +640,7 @@ class TestUserExceptionHandlers:
): ):
with patch.object(session, "rollback", new_callable=AsyncMock): with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk update failed"): with pytest.raises(Exception, match="Bulk update failed"):
await user_crud.bulk_update_status( await user_repo.bulk_update_status(
session, user_ids=[async_test_user.id], is_active=False session, user_ids=[async_test_user.id], is_active=False
) )
@@ -660,6 +660,6 @@ class TestUserExceptionHandlers:
): ):
with patch.object(session, "rollback", new_callable=AsyncMock): with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk delete failed"): with pytest.raises(Exception, match="Bulk delete failed"):
await user_crud.bulk_soft_delete( await user_repo.bulk_soft_delete(
session, user_ids=[async_test_user.id] session, user_ids=[async_test_user.id]
) )

View File

@@ -206,13 +206,13 @@ class TestCleanupExpiredSessions:
"""Test cleanup returns 0 on database errors (doesn't crash).""" """Test cleanup returns 0 on database errors (doesn't crash)."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Mock session_crud.cleanup_expired to raise error # Mock session_repo.cleanup_expired to raise error
with patch( with patch(
"app.services.session_cleanup.SessionLocal", "app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(), return_value=AsyncTestingSessionLocal(),
): ):
with patch( with patch(
"app.services.session_cleanup.session_crud.cleanup_expired" "app.services.session_cleanup.session_repo.cleanup_expired"
) as mock_cleanup: ) as mock_cleanup:
mock_cleanup.side_effect = Exception("Database connection lost") mock_cleanup.side_effect = Exception("Database connection lost")

View File

@@ -91,9 +91,9 @@ class TestInitDb:
"""Test that init_db handles database errors gracefully.""" """Test that init_db handles database errors gracefully."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Mock user_crud.get_by_email to raise an exception # Mock user_repo.get_by_email to raise an exception
with patch( with patch(
"app.init_db.user_crud.get_by_email", "app.init_db.user_repo.get_by_email",
side_effect=Exception("Database error"), side_effect=Exception("Database error"),
): ):
with patch("app.init_db.SessionLocal", SessionLocal): with patch("app.init_db.SessionLocal", SessionLocal):

View File

@@ -29,7 +29,7 @@ Production-ready Next.js 16 frontend with TypeScript, authentication, admin pane
### Admin Panel ### Admin Panel
- 👥 **User Administration** - CRUD operations, search, filters - 👥 **User Administration** - Full lifecycle operations, search, filters
- 🏢 **Organization Management** - Multi-tenant support with roles - 🏢 **Organization Management** - Multi-tenant support with roles
- 📊 **Dashboard** - Statistics and quick actions - 📊 **Dashboard** - Statistics and quick actions
- 🔍 **Advanced Filtering** - Status, search, pagination - 🔍 **Advanced Filtering** - Status, search, pagination

View File

@@ -1040,7 +1040,7 @@ export default function AdminDashboardPage() {
These examples demonstrate: These examples demonstrate:
1. **Complete CRUD operations** (User Management) 1. **Complete management operations** (User Management)
2. **Real-time data with polling** (Session Management) 2. **Real-time data with polling** (Session Management)
3. **Data visualization** (Admin Dashboard Charts) 3. **Data visualization** (Admin Dashboard Charts)

View File

@@ -1780,7 +1780,7 @@ The frontend template will be considered complete when:
1. **Functionality:** 1. **Functionality:**
- All specified pages are implemented and functional - All specified pages are implemented and functional
- Authentication flow works end-to-end - Authentication flow works end-to-end
- User and organization CRUD operations work - User and organization management operations work
- API integration is complete and reliable - API integration is complete and reliable
2. **Code Quality:** 2. **Code Quality:**