Add extensive CRUD tests for session and user management; enhance cleanup logic
- Introduced new unit tests for session CRUD operations, including `update_refresh_token`, `cleanup_expired`, and multi-user session handling. - Added comprehensive tests for `CRUDBase` methods, covering edge cases, error handling, and UUID validation. - Reduced default test session creation from 5 to 2 for performance optimization. - Enhanced pagination, filtering, and sorting validations in `get_multi_with_total`. - Improved error handling with descriptive assertions for database exceptions. - Introduced tests for eager-loaded relationships in user sessions for comprehensive coverage.
This commit is contained in:
324
backend/tests/api/test_auth.py
Normal file
324
backend/tests/api/test_auth.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# tests/api/test_auth.py
|
||||
"""
|
||||
Tests for authentication endpoints.
|
||||
"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
|
||||
|
||||
class TestRegisterEndpoint:
|
||||
"""Tests for POST /auth/register endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_success(self, client):
|
||||
"""Test successful user registration."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": "newuser@example.com",
|
||||
"password": "NewPassword123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
assert data["email"] == "newuser@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_duplicate_email(self, client, async_test_user):
|
||||
"""Test registration with duplicate email."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!",
|
||||
"first_name": "Test",
|
||||
"last_name": "User"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_weak_password(self, client):
|
||||
"""Test registration with weak password."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": "test@example.com",
|
||||
"password": "weak",
|
||||
"first_name": "Test",
|
||||
"last_name": "User"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
class TestLoginEndpoint:
|
||||
"""Tests for POST /auth/login endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_success(self, client, async_test_user):
|
||||
"""Test successful login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_invalid_credentials(self, client, async_test_user):
|
||||
"""Test login with invalid password."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "WrongPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_nonexistent_user(self, client):
|
||||
"""Test login with non-existent user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "nonexistent@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_inactive_user(self, client, async_test_db):
|
||||
"""Test login with inactive user."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
inactive_user = User(
|
||||
email="inactive@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="Inactive",
|
||||
last_name="User",
|
||||
is_active=False
|
||||
)
|
||||
session.add(inactive_user)
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "inactive@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestRefreshTokenEndpoint:
|
||||
"""Tests for POST /auth/refresh endpoint."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def refresh_token(self, client, async_test_user):
|
||||
"""Get a refresh token for testing."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
return response.json()["refresh_token"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_success(self, client, refresh_token):
|
||||
"""Test successful token refresh."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_invalid(self, client):
|
||||
"""Test refresh with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid.token.here"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestLogoutEndpoint:
|
||||
"""Tests for POST /auth/logout endpoint."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def tokens(self, client, async_test_user):
|
||||
"""Get tokens for testing."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
data = response.json()
|
||||
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_success(self, client, tokens):
|
||||
"""Test successful logout."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_without_auth(self, client):
|
||||
"""Test logout without authentication."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": "some.token"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestPasswordResetRequest:
|
||||
"""Tests for POST /auth/password-reset/request endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_success(self, client, async_test_user):
|
||||
"""Test password reset request with existing user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_nonexistent_email(self, client):
|
||||
"""Test password reset request with non-existent email."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": "nonexistent@example.com"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
|
||||
class TestPasswordResetConfirm:
|
||||
"""Tests for POST /auth/password-reset/confirm endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_confirm_invalid_token(self, client):
|
||||
"""Test password reset with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": "invalid.token.here",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
|
||||
class TestLogoutAll:
|
||||
"""Tests for POST /auth/logout-all endpoint."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def tokens(self, client, async_test_user):
|
||||
"""Get tokens for testing."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
data = response.json()
|
||||
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_all_success(self, client, tokens):
|
||||
"""Test logout from all devices."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "sessions terminated" in data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_all_unauthorized(self, client):
|
||||
"""Test logout-all without authentication."""
|
||||
response = await client.post("/api/v1/auth/logout-all")
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestOAuthLogin:
|
||||
"""Tests for POST /auth/login/oauth endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_login_success(self, client, async_test_user):
|
||||
"""Test successful OAuth login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_login_invalid_credentials(self, client, async_test_user):
|
||||
"""Test OAuth login with invalid credentials."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": "testuser@example.com",
|
||||
"password": "WrongPassword"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -6,9 +6,9 @@ from unittest.mock import patch
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
"""Create a FastAPI test client for the main app."""
|
||||
"""Create a FastAPI test client for the main app (module-scoped for speed)."""
|
||||
# Mock get_db to avoid database connection issues
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
async def mock_session_generator():
|
||||
@@ -25,46 +25,38 @@ def client():
|
||||
class TestSecurityHeaders:
|
||||
"""Tests for security headers middleware"""
|
||||
|
||||
def test_x_frame_options_header(self, client):
|
||||
"""Test that X-Frame-Options header is set to DENY"""
|
||||
def test_all_security_headers(self, client):
|
||||
"""Test all security headers in a single request for speed"""
|
||||
response = client.get("/health")
|
||||
|
||||
# Test X-Frame-Options
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert response.headers["X-Frame-Options"] == "DENY"
|
||||
|
||||
def test_x_content_type_options_header(self, client):
|
||||
"""Test that X-Content-Type-Options header is set to nosniff"""
|
||||
response = client.get("/health")
|
||||
# Test X-Content-Type-Options
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_x_xss_protection_header(self, client):
|
||||
"""Test that X-XSS-Protection header is set"""
|
||||
response = client.get("/health")
|
||||
# Test X-XSS-Protection
|
||||
assert "X-XSS-Protection" in response.headers
|
||||
assert response.headers["X-XSS-Protection"] == "1; mode=block"
|
||||
|
||||
def test_content_security_policy_header(self, client):
|
||||
"""Test that Content-Security-Policy header is set"""
|
||||
response = client.get("/health")
|
||||
# Test Content-Security-Policy
|
||||
assert "Content-Security-Policy" in response.headers
|
||||
assert "default-src 'self'" in response.headers["Content-Security-Policy"]
|
||||
assert "frame-ancestors 'none'" in response.headers["Content-Security-Policy"]
|
||||
|
||||
def test_permissions_policy_header(self, client):
|
||||
"""Test that Permissions-Policy header is set"""
|
||||
response = client.get("/health")
|
||||
# Test Permissions-Policy
|
||||
assert "Permissions-Policy" in response.headers
|
||||
assert "geolocation=()" in response.headers["Permissions-Policy"]
|
||||
assert "microphone=()" in response.headers["Permissions-Policy"]
|
||||
assert "camera=()" in response.headers["Permissions-Policy"]
|
||||
|
||||
def test_referrer_policy_header(self, client):
|
||||
"""Test that Referrer-Policy header is set"""
|
||||
response = client.get("/health")
|
||||
# Test Referrer-Policy
|
||||
assert "Referrer-Policy" in response.headers
|
||||
assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
|
||||
|
||||
def test_strict_transport_security_not_in_development(self, client):
|
||||
def test_hsts_not_in_development(self, client):
|
||||
"""Test that Strict-Transport-Security header is not set in development"""
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -73,18 +65,6 @@ class TestSecurityHeaders:
|
||||
response = client.get("/health")
|
||||
assert "Strict-Transport-Security" not in response.headers
|
||||
|
||||
def test_security_headers_on_all_endpoints(self, client):
|
||||
"""Test that security headers are present on all endpoints"""
|
||||
# Test health endpoint
|
||||
response = client.get("/health")
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
|
||||
# Test root endpoint
|
||||
response = client.get("/")
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
|
||||
def test_security_headers_on_404(self, client):
|
||||
"""Test that security headers are present even on 404 responses"""
|
||||
response = client.get("/nonexistent-endpoint")
|
||||
|
||||
@@ -365,3 +365,99 @@ class TestCleanupExpiredSessions:
|
||||
response = await client.delete("/api/v1/sessions/me/expired")
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
# Additional tests for better coverage
|
||||
|
||||
class TestSessionsAdditionalCases:
|
||||
"""Additional tests to improve sessions endpoint coverage."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_pagination(self, client, async_test_user, async_test_db, user_token):
|
||||
"""Test listing sessions with pagination."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create multiple sessions
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.session import session as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
for i in range(5):
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name=f"Device {i}",
|
||||
ip_address=f"192.168.1.{i}",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
await session_crud.create_session(session, obj_in=session_data)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me?page=1&limit=3",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "sessions" in data
|
||||
assert "total" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session_invalid_uuid(self, client, user_token):
|
||||
"""Test revoking session with invalid UUID."""
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/not-a-uuid",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
# Should return 422 for invalid UUID format
|
||||
assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_404_NOT_FOUND]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_with_mixed_states(self, client, async_test_user, async_test_db, user_token):
|
||||
"""Test cleanup with mix of active/inactive and expired/not-expired sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
async with SessionLocal() as db:
|
||||
# Expired + inactive (should be cleaned)
|
||||
e1_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Expired Inactive",
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
)
|
||||
e1 = await session_crud.create_session(db, obj_in=e1_data)
|
||||
e1.is_active = False
|
||||
db.add(e1)
|
||||
|
||||
# Expired but still active (should NOT be cleaned - only inactive+expired)
|
||||
e2_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Expired Active",
|
||||
ip_address="192.168.1.101",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
|
||||
)
|
||||
await session_crud.create_session(db, obj_in=e2_data)
|
||||
|
||||
await db.commit()
|
||||
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
759
backend/tests/crud/test_base.py
Normal file
759
backend/tests/crud/test_base.py
Normal file
@@ -0,0 +1,759 @@
|
||||
# tests/crud/test_base.py
|
||||
"""
|
||||
Comprehensive tests for CRUDBase class covering all error paths and edge cases.
|
||||
"""
|
||||
import pytest
|
||||
from uuid import uuid4, UUID
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class TestCRUDBaseGet:
|
||||
"""Tests for get method covering UUID validation and options."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_invalid_uuid_string(self, async_test_db):
|
||||
"""Test get with invalid UUID string returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.get(session, id="invalid-uuid")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_invalid_uuid_type(self, async_test_db):
|
||||
"""Test get with invalid UUID type returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.get(session, id=12345) # int instead of UUID
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_uuid_object(self, async_test_db, async_test_user):
|
||||
"""Test get with UUID object instead of string."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Pass UUID object directly
|
||||
result = await user_crud.get(session, id=async_test_user.id)
|
||||
assert result is not None
|
||||
assert result.id == async_test_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_options(self, async_test_db, async_test_user):
|
||||
"""Test get with eager loading options (tests lines 76-78)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Test that options parameter is accepted and doesn't error
|
||||
# We pass an empty list which still tests the code path
|
||||
result = await user_crud.get(
|
||||
session,
|
||||
id=str(async_test_user.id),
|
||||
options=[]
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_database_error(self, async_test_db):
|
||||
"""Test get handles database errors properly."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock execute to raise an exception
|
||||
with patch.object(session, 'execute', side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.get(session, id=str(uuid4()))
|
||||
|
||||
|
||||
class TestCRUDBaseGetMulti:
|
||||
"""Tests for get_multi method covering pagination validation and options."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_negative_skip(self, async_test_db):
|
||||
"""Test get_multi with negative skip raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
await user_crud.get_multi(session, skip=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_negative_limit(self, async_test_db):
|
||||
"""Test get_multi with negative limit raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
await user_crud.get_multi(session, limit=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_limit_too_large(self, async_test_db):
|
||||
"""Test get_multi with limit > 1000 raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi(session, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_options(self, async_test_db, async_test_user):
|
||||
"""Test get_multi with eager loading options (tests lines 118-120)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Test that options parameter is accepted
|
||||
results = await user_crud.get_multi(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
options=[]
|
||||
)
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_database_error(self, async_test_db):
|
||||
"""Test get_multi handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'execute', side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.get_multi(session)
|
||||
|
||||
|
||||
class TestCRUDBaseCreate:
|
||||
"""Tests for create method covering various error conditions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_unique_field(self, async_test_db, async_test_user):
|
||||
"""Test create with duplicate unique field raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Try to create user with duplicate email
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Duplicate!
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="Duplicate"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_integrity_error_non_duplicate(self, async_test_db):
|
||||
"""Test create with non-duplicate IntegrityError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock commit to raise IntegrityError without "unique" in message
|
||||
original_commit = session.commit
|
||||
|
||||
async def mock_commit():
|
||||
error = IntegrityError("statement", {}, Exception("foreign key violation"))
|
||||
raise error
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_operational_error(self, async_test_db):
|
||||
"""Test create with OperationalError (user CRUD catches as generic Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection lost"))):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_data_error(self, async_test_db):
|
||||
"""Test create with DataError (user CRUD catches as generic Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=DataError("statement", {}, Exception("invalid data"))):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
with pytest.raises(DataError):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_unexpected_error(self, async_test_db):
|
||||
"""Test create with unexpected exception."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected error")):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Unexpected error"):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
|
||||
class TestCRUDBaseUpdate:
|
||||
"""Tests for update method covering error conditions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_duplicate_unique_field(self, async_test_db, async_test_user):
|
||||
"""Test update with duplicate unique field raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create another user
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.user import user as user_crud
|
||||
user2_data = UserCreate(
|
||||
email="user2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="User",
|
||||
last_name="Two"
|
||||
)
|
||||
user2 = await user_crud.create(session, obj_in=user2_data)
|
||||
await session.commit()
|
||||
|
||||
# Try to update user2 with user1's email
|
||||
async with SessionLocal() as session:
|
||||
user2_obj = await user_crud.get(session, id=str(user2.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("UNIQUE constraint failed"))):
|
||||
update_data = UserUpdate(email=async_test_user.email)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
await user_crud.update(session, db_obj=user2_obj, obj_in=update_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test update with dict instead of schema."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
# Update with dict (tests lines 164-165)
|
||||
updated = await user_crud.update(
|
||||
session,
|
||||
db_obj=user,
|
||||
obj_in={"first_name": "UpdatedName"}
|
||||
)
|
||||
assert updated.first_name == "UpdatedName"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_integrity_error(self, async_test_db, async_test_user):
|
||||
"""Test update with IntegrityError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("constraint failed"))):
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_operational_error(self, async_test_db, async_test_user):
|
||||
"""Test update with OperationalError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection error"))):
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test update with unexpected error."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
|
||||
with pytest.raises(RuntimeError):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
|
||||
|
||||
|
||||
class TestCRUDBaseRemove:
|
||||
"""Tests for remove method covering UUID validation and error conditions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_invalid_uuid(self, async_test_db):
|
||||
"""Test remove with invalid UUID returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id="invalid-uuid")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_with_uuid_object(self, async_test_db, async_test_user):
|
||||
"""Test remove with UUID object."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a user to delete
|
||||
async with SessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="todelete@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="To",
|
||||
last_name="Delete"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
# Delete with UUID object
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id=user_id) # UUID object
|
||||
assert result is not None
|
||||
assert result.id == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_nonexistent(self, async_test_db):
|
||||
"""Test remove of nonexistent record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id=str(uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_integrity_error(self, async_test_db, async_test_user):
|
||||
"""Test remove with IntegrityError (foreign key constraint)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock delete to raise IntegrityError
|
||||
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("FOREIGN KEY constraint"))):
|
||||
with pytest.raises(ValueError, match="Cannot delete.*referenced by other records"):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test remove with unexpected error."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
|
||||
with pytest.raises(RuntimeError):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
|
||||
class TestCRUDBaseGetMultiWithTotal:
|
||||
"""Tests for get_multi_with_total method covering pagination, filtering, sorting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
|
||||
"""Test get_multi_with_total basic functionality."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, total = await user_crud.get_multi_with_total(session, skip=0, limit=10)
|
||||
assert isinstance(items, list)
|
||||
assert isinstance(total, int)
|
||||
assert total >= 1 # At least the test user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_skip(self, async_test_db):
|
||||
"""Test get_multi_with_total with negative skip raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, skip=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_limit(self, async_test_db):
|
||||
"""Test get_multi_with_total with negative limit raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, limit=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
|
||||
"""Test get_multi_with_total with limit > 1000 raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
|
||||
"""Test get_multi_with_total with filters."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
filters = {"email": async_test_user.email}
|
||||
items, total = await user_crud.get_multi_with_total(session, filters=filters)
|
||||
assert total == 1
|
||||
assert len(items) == 1
|
||||
assert items[0].email == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db, async_test_user):
|
||||
"""Test get_multi_with_total with ascending sort."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
user_data1 = UserCreate(
|
||||
email="aaa@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="AAA",
|
||||
last_name="User"
|
||||
)
|
||||
user_data2 = UserCreate(
|
||||
email="zzz@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="ZZZ",
|
||||
last_name="User"
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
session, sort_by="email", sort_order="asc"
|
||||
)
|
||||
assert total >= 3
|
||||
# Check first email is alphabetically first
|
||||
assert items[0].email == "aaa@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db, async_test_user):
|
||||
"""Test get_multi_with_total with descending sort."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
user_data1 = UserCreate(
|
||||
email="bbb@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="BBB",
|
||||
last_name="User"
|
||||
)
|
||||
user_data2 = UserCreate(
|
||||
email="ccc@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="CCC",
|
||||
last_name="User"
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
session, sort_by="email", sort_order="desc", limit=1
|
||||
)
|
||||
assert len(items) == 1
|
||||
# First item should have higher email alphabetically
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_pagination(self, async_test_db):
|
||||
"""Test get_multi_with_total pagination works correctly."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create minimal users for pagination test (3 instead of 5)
|
||||
async with SessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"user{i}@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Get first page
|
||||
items1, total = await user_crud.get_multi_with_total(session, skip=0, limit=2)
|
||||
assert len(items1) == 2
|
||||
assert total >= 3
|
||||
|
||||
# Get second page
|
||||
items2, total2 = await user_crud.get_multi_with_total(session, skip=2, limit=2)
|
||||
assert len(items2) >= 1
|
||||
assert total2 == total
|
||||
|
||||
# Ensure no overlap
|
||||
ids1 = {item.id for item in items1}
|
||||
ids2 = {item.id for item in items2}
|
||||
assert ids1.isdisjoint(ids2)
|
||||
|
||||
|
||||
class TestCRUDBaseCount:
|
||||
"""Tests for count method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_basic(self, async_test_db, async_test_user):
|
||||
"""Test count returns correct number."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
count = await user_crud.count(session)
|
||||
assert isinstance(count, int)
|
||||
assert count >= 1 # At least the test user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_multiple_users(self, async_test_db, async_test_user):
|
||||
"""Test count with multiple users."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
initial_count = await user_crud.count(session)
|
||||
|
||||
user_data1 = UserCreate(
|
||||
email="count1@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Count",
|
||||
last_name="One"
|
||||
)
|
||||
user_data2 = UserCreate(
|
||||
email="count2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Count",
|
||||
last_name="Two"
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
new_count = await user_crud.count(session)
|
||||
assert new_count == initial_count + 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_database_error(self, async_test_db):
|
||||
"""Test count handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'execute', side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.count(session)
|
||||
|
||||
|
||||
class TestCRUDBaseExists:
|
||||
"""Tests for exists method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_true(self, async_test_db, async_test_user):
|
||||
"""Test exists returns True for existing record."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id=str(async_test_user.id))
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_false(self, async_test_db):
|
||||
"""Test exists returns False for non-existent record."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id=str(uuid4()))
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_invalid_uuid(self, async_test_db):
|
||||
"""Test exists returns False for invalid UUID."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id="invalid-uuid")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCRUDBaseSoftDelete:
|
||||
"""Tests for soft_delete method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_success(self, async_test_db):
|
||||
"""Test soft delete sets deleted_at timestamp."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a user to soft delete
|
||||
async with SessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="softdelete@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Soft",
|
||||
last_name="Delete"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
# Soft delete the user
|
||||
async with SessionLocal() as session:
|
||||
deleted = await user_crud.soft_delete(session, id=str(user_id))
|
||||
assert deleted is not None
|
||||
assert deleted.deleted_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_invalid_uuid(self, async_test_db):
|
||||
"""Test soft delete with invalid UUID returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.soft_delete(session, id="invalid-uuid")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_nonexistent(self, async_test_db):
|
||||
"""Test soft delete of nonexistent record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.soft_delete(session, id=str(uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_with_uuid_object(self, async_test_db):
|
||||
"""Test soft delete with UUID object."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a user to soft delete
|
||||
async with SessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="softdelete2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Soft",
|
||||
last_name="Delete2"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
# Soft delete with UUID object
|
||||
async with SessionLocal() as session:
|
||||
deleted = await user_crud.soft_delete(session, id=user_id) # UUID object
|
||||
assert deleted is not None
|
||||
assert deleted.deleted_at is not None
|
||||
|
||||
|
||||
class TestCRUDBaseRestore:
|
||||
"""Tests for restore method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_success(self, async_test_db):
|
||||
"""Test restore clears deleted_at timestamp."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="restore@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
await user_crud.soft_delete(session, id=str(user_id))
|
||||
|
||||
# Restore the user
|
||||
async with SessionLocal() as session:
|
||||
restored = await user_crud.restore(session, id=str(user_id))
|
||||
assert restored is not None
|
||||
assert restored.deleted_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_invalid_uuid(self, async_test_db):
|
||||
"""Test restore with invalid UUID returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.restore(session, id="invalid-uuid")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_nonexistent(self, async_test_db):
|
||||
"""Test restore of nonexistent record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.restore(session, id=str(uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_not_deleted(self, async_test_db, async_test_user):
|
||||
"""Test restore of non-deleted record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Try to restore a user that's not deleted
|
||||
result = await user_crud.restore(session, id=str(async_test_user.id))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_with_uuid_object(self, async_test_db):
|
||||
"""Test restore with UUID object."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="restore2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test2"
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
await user_crud.soft_delete(session, id=str(user_id))
|
||||
|
||||
# Restore with UUID object
|
||||
async with SessionLocal() as session:
|
||||
restored = await user_crud.restore(session, id=user_id) # UUID object
|
||||
assert restored is not None
|
||||
assert restored.deleted_at is None
|
||||
@@ -245,7 +245,8 @@ class TestDeactivateAllUserSessions:
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(5):
|
||||
# Create minimal sessions for test (2 instead of 5)
|
||||
for i in range(2):
|
||||
sess = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"bulk_{i}",
|
||||
@@ -264,7 +265,7 @@ class TestDeactivateAllUserSessions:
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 5
|
||||
assert count == 2
|
||||
|
||||
|
||||
class TestUpdateLastUsed:
|
||||
@@ -337,3 +338,227 @@ class TestGetUserSessionCount:
|
||||
user_id=str(uuid4())
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestUpdateRefreshToken:
|
||||
"""Tests for update_refresh_token method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token_success(self, async_test_db, async_test_user):
|
||||
"""Test updating refresh token JTI and expiration."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="old_jti",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
new_jti = "new_jti_123"
|
||||
new_expires = datetime.now(timezone.utc) + timedelta(days=14)
|
||||
|
||||
result = await session_crud.update_refresh_token(
|
||||
session,
|
||||
session=user_session,
|
||||
new_jti=new_jti,
|
||||
new_expires_at=new_expires
|
||||
)
|
||||
|
||||
assert result.refresh_token_jti == new_jti
|
||||
# Compare timestamps ignoring timezone info
|
||||
assert abs((result.expires_at.replace(tzinfo=None) - new_expires.replace(tzinfo=None)).total_seconds()) < 1
|
||||
|
||||
|
||||
class TestCleanupExpired:
|
||||
"""Tests for cleanup_expired method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_success(self, async_test_db, async_test_user):
|
||||
"""Test cleaning up old expired inactive sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create old expired inactive session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
old_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="old_expired",
|
||||
device_name="Old Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=35)
|
||||
)
|
||||
session.add(old_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired(session, keep_days=30)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup keeps recent expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create recent expired inactive session (less than keep_days old)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
recent_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="recent_expired",
|
||||
device_name="Recent Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1)
|
||||
)
|
||||
session.add(recent_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired(session, keep_days=30)
|
||||
assert count == 0 # Should not delete recent sessions
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup does not delete active sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create old expired but ACTIVE session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active_expired",
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True, # Active
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=35)
|
||||
)
|
||||
session.add(active_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired(session, keep_days=30)
|
||||
assert count == 0 # Should not delete active sessions
|
||||
|
||||
|
||||
class TestCleanupExpiredForUser:
|
||||
"""Tests for cleanup_expired_for_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_success(self, async_test_db, async_test_user):
|
||||
"""Test cleaning up expired sessions for specific user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create expired inactive session for user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
expired_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="user_expired",
|
||||
device_name="Expired Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
)
|
||||
session.add(expired_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup for user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db):
|
||||
"""Test cleanup with invalid user UUID."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Invalid user ID format"):
|
||||
await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id="not-a-valid-uuid"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_keeps_active(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup for user keeps active sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create expired but active session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active_user_expired",
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True, # Active
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
)
|
||||
session.add(active_session)
|
||||
await session.commit()
|
||||
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 0 # Should not delete active sessions
|
||||
|
||||
|
||||
class TestGetUserSessionsWithUser:
|
||||
"""Tests for get_user_sessions with eager loading."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_with_user_relationship(self, async_test_db, async_test_user):
|
||||
"""Test getting sessions with user relationship loaded."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="with_user",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
|
||||
# Get with user relationship
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
with_user=True
|
||||
)
|
||||
assert len(results) >= 1
|
||||
|
||||
84
backend/tests/test_init_db.py
Normal file
84
backend/tests/test_init_db.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# tests/test_init_db.py
|
||||
"""
|
||||
Tests for database initialization script.
|
||||
"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.init_db import init_db
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestInitDb:
|
||||
"""Tests for init_db functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_creates_superuser_when_not_exists(self, async_test_db):
|
||||
"""Test that init_db creates a superuser when one doesn't exist."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock the SessionLocal to use our test database
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
# Mock settings to provide test credentials
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test_admin@example.com'):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestAdmin123!'):
|
||||
# Run init_db
|
||||
user = await init_db()
|
||||
|
||||
# Verify superuser was created
|
||||
assert user is not None
|
||||
assert user.email == 'test_admin@example.com'
|
||||
assert user.is_superuser is True
|
||||
assert user.first_name == 'Admin'
|
||||
assert user.last_name == 'User'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_returns_existing_superuser(self, async_test_db, async_test_user):
|
||||
"""Test that init_db returns existing superuser instead of creating duplicate."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock the SessionLocal to use our test database
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
# Mock settings to match async_test_user's email
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'testuser@example.com'):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
|
||||
# Run init_db
|
||||
user = await init_db()
|
||||
|
||||
# Verify it returns the existing user
|
||||
assert user is not None
|
||||
assert user.id == async_test_user.id
|
||||
assert user.email == 'testuser@example.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_uses_default_credentials(self, async_test_db):
|
||||
"""Test that init_db uses default credentials when env vars not set."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock the SessionLocal to use our test database
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
# Mock settings to have None values (not configured)
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', None):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', None):
|
||||
# Run init_db
|
||||
user = await init_db()
|
||||
|
||||
# Verify superuser was created with defaults
|
||||
assert user is not None
|
||||
assert user.email == 'admin@example.com'
|
||||
assert user.is_superuser is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_handles_database_errors(self, async_test_db):
|
||||
"""Test that init_db handles database errors gracefully."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock user_crud.get_by_email to raise an exception
|
||||
with patch('app.init_db.user_crud.get_by_email', side_effect=Exception("Database error")):
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test@example.com'):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
|
||||
# Run init_db and expect it to raise
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
await init_db()
|
||||
Reference in New Issue
Block a user