diff --git a/backend/app/alembic/versions/1174fffbe3e4_add_performance_indexes.py b/backend/app/alembic/versions/1174fffbe3e4_add_performance_indexes.py new file mode 100644 index 0000000..9f58956 --- /dev/null +++ b/backend/app/alembic/versions/1174fffbe3e4_add_performance_indexes.py @@ -0,0 +1,78 @@ +"""add_performance_indexes + +Revision ID: 1174fffbe3e4 +Revises: fbf6318a8a36 +Create Date: 2025-11-01 04:15:25.367010 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '1174fffbe3e4' +down_revision: Union[str, None] = 'fbf6318a8a36' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add performance indexes for optimized queries.""" + + # Index for session cleanup queries + # Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff + op.create_index( + 'ix_user_sessions_cleanup', + 'user_sessions', + ['is_active', 'expires_at', 'created_at'], + unique=False, + postgresql_where=sa.text('is_active = false') + ) + + # Index for user search queries (basic trigram support without pg_trgm extension) + # Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%' + # Note: For better performance, consider enabling pg_trgm extension + op.create_index( + 'ix_users_email_lower', + 'users', + [sa.text('LOWER(email)')], + unique=False, + postgresql_where=sa.text('deleted_at IS NULL') + ) + + op.create_index( + 'ix_users_first_name_lower', + 'users', + [sa.text('LOWER(first_name)')], + unique=False, + postgresql_where=sa.text('deleted_at IS NULL') + ) + + op.create_index( + 'ix_users_last_name_lower', + 'users', + [sa.text('LOWER(last_name)')], + unique=False, + postgresql_where=sa.text('deleted_at IS NULL') + ) + + # Index for organization search + op.create_index( + 'ix_organizations_name_lower', + 'organizations', + [sa.text('LOWER(name)')], + unique=False + ) + + +def downgrade() -> None: + """Remove performance indexes.""" + + # Drop indexes in reverse order + op.drop_index('ix_organizations_name_lower', table_name='organizations') + op.drop_index('ix_users_last_name_lower', table_name='users') + op.drop_index('ix_users_first_name_lower', table_name='users') + op.drop_index('ix_users_email_lower', table_name='users') + op.drop_index('ix_user_sessions_cleanup', table_name='user_sessions') diff --git a/backend/app/api/routes/admin.py b/backend/app/api/routes/admin.py index 95dbca1..6dafe2f 100755 --- a/backend/app/api/routes/admin.py +++ b/backend/app/api/routes/admin.py @@ -145,7 +145,7 @@ async def admin_create_user( except ValueError as e: logger.warning(f"Failed to create user: {str(e)}") raise NotFoundError( - detail=str(e), + message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS ) except Exception as e: @@ -169,7 +169,7 @@ async def admin_get_user( user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - detail=f"User {user_id} not found", + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) return user @@ -193,7 +193,7 @@ async def admin_update_user( user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - detail=f"User {user_id} not found", + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) @@ -225,7 +225,7 @@ async def admin_delete_user( user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - detail=f"User {user_id} not found", + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) @@ -269,7 +269,7 @@ async def admin_activate_user( user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - detail=f"User {user_id} not found", + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) @@ -305,7 +305,7 @@ async def admin_deactivate_user( user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - detail=f"User {user_id} not found", + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) @@ -491,7 +491,7 @@ async def admin_create_organization( except ValueError as e: logger.warning(f"Failed to create organization: {str(e)}") raise NotFoundError( - detail=str(e), + message=str(e), error_code=ErrorCode.ALREADY_EXISTS ) except Exception as e: @@ -515,7 +515,7 @@ async def admin_get_organization( org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( - detail=f"Organization {org_id} not found", + message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) @@ -551,7 +551,7 @@ async def admin_update_organization( org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( - detail=f"Organization {org_id} not found", + message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) @@ -595,7 +595,7 @@ async def admin_delete_organization( org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( - detail=f"Organization {org_id} not found", + message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) @@ -633,7 +633,7 @@ async def admin_list_organization_members( org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( - detail=f"Organization {org_id} not found", + message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) @@ -688,14 +688,14 @@ async def admin_add_organization_member( org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( - detail=f"Organization {org_id} not found", + message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) user = await user_crud.get(db, id=request.user_id) if not user: raise NotFoundError( - detail=f"User {request.user_id} not found", + message=f"User {request.user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) @@ -749,14 +749,14 @@ async def admin_remove_organization_member( org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( - detail=f"Organization {org_id} not found", + message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - detail=f"User {user_id} not found", + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) @@ -768,7 +768,7 @@ async def admin_remove_organization_member( if not success: raise NotFoundError( - detail="User is not a member of this organization", + message="User is not a member of this organization", error_code=ErrorCode.NOT_FOUND ) diff --git a/backend/app/schemas/users.py b/backend/app/schemas/users.py index 564e6ff..fb98918 100644 --- a/backend/app/schemas/users.py +++ b/backend/app/schemas/users.py @@ -35,6 +35,7 @@ class UserUpdate(BaseModel): first_name: Optional[str] = None last_name: Optional[str] = None phone_number: Optional[str] = None + password: Optional[str] = None preferences: Optional[Dict[str, Any]] = None is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates @@ -43,6 +44,14 @@ class UserUpdate(BaseModel): def validate_phone(cls, v: Optional[str]) -> Optional[str]: return validate_phone_number(v) + @field_validator('password') + @classmethod + def password_strength(cls, v: Optional[str]) -> Optional[str]: + """Enterprise-grade password strength validation""" + if v is None: + return v + return validate_password_strength(v) + class UserInDB(UserBase): id: UUID diff --git a/backend/app/utils/device.py b/backend/app/utils/device.py index c6614d6..d4842c3 100644 --- a/backend/app/utils/device.py +++ b/backend/app/utils/device.py @@ -68,6 +68,22 @@ def parse_device_name(user_agent: str) -> Optional[str]: elif 'windows phone' in user_agent_lower: return "Windows Phone" + # Tablets (check before desktop, as some tablets contain "android") + elif 'tablet' in user_agent_lower: + return "Tablet" + + # Smart TVs (check before desktop OS patterns) + elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv']): + return "Smart TV" + + # Game consoles (check before desktop OS patterns, as Xbox contains "Windows") + elif 'playstation' in user_agent_lower: + return "PlayStation" + elif 'xbox' in user_agent_lower: + return "Xbox" + elif 'nintendo' in user_agent_lower: + return "Nintendo" + # Desktop operating systems elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower: # Try to extract browser @@ -82,22 +98,6 @@ def parse_device_name(user_agent: str) -> Optional[str]: elif 'cros' in user_agent_lower: return "Chromebook" - # Tablets (not already caught) - elif 'tablet' in user_agent_lower: - return "Tablet" - - # Smart TVs - elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv', 'tv']): - return "Smart TV" - - # Game consoles - elif 'playstation' in user_agent_lower: - return "PlayStation" - elif 'xbox' in user_agent_lower: - return "Xbox" - elif 'nintendo' in user_agent_lower: - return "Nintendo" - # Fallback: just return browser name if detected browser = extract_browser(user_agent) if browser: diff --git a/backend/tests/api/test_admin.py b/backend/tests/api/test_admin.py new file mode 100644 index 0000000..f9f7cb9 --- /dev/null +++ b/backend/tests/api/test_admin.py @@ -0,0 +1,839 @@ +# tests/api/test_admin.py +""" +Comprehensive tests for admin endpoints. +""" +import pytest +import pytest_asyncio +from uuid import uuid4 +from fastapi import status + +from app.models.organization import Organization +from app.models.user_organization import UserOrganization, OrganizationRole + + +@pytest_asyncio.fixture +async def superuser_token(client, async_test_superuser): + """Get access token for superuser.""" + response = await client.post( + "/api/v1/auth/login", + json={ + "email": "superuser@example.com", + "password": "SuperPassword123!" + } + ) + assert response.status_code == 200, f"Login failed: {response.json()}" + return response.json()["access_token"] + + +# ===== USER MANAGEMENT TESTS ===== + +class TestAdminListUsers: + """Tests for GET /admin/users endpoint.""" + + @pytest.mark.asyncio + async def test_admin_list_users_success(self, client, superuser_token): + """Test successfully listing users as admin.""" + response = await client.get( + "/api/v1/admin/users", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "data" in data + assert "pagination" in data + assert isinstance(data["data"], list) + + @pytest.mark.asyncio + async def test_admin_list_users_with_filters(self, client, async_test_superuser, async_test_db, superuser_token): + """Test listing users with filters.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create inactive user + async with AsyncTestingSessionLocal() as session: + from app.models.user import User + from app.core.auth import get_password_hash + inactive_user = User( + email="inactive@example.com", + password_hash=get_password_hash("TestPassword123!"), + first_name="Inactive", + last_name="User", + is_active=False + ) + session.add(inactive_user) + await session.commit() + + response = await client.get( + "/api/v1/admin/users?is_active=false", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["data"]) >= 1 + + @pytest.mark.asyncio + async def test_admin_list_users_with_search(self, client, async_test_superuser, superuser_token): + """Test searching users.""" + response = await client.get( + "/api/v1/admin/users?search=superuser", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "data" in data + + @pytest.mark.asyncio + async def test_admin_list_users_unauthorized(self, client, async_test_user): + """Test non-admin cannot list users.""" + # Login as regular user + login_response = await client.post( + "/api/v1/auth/login", + json={"email": async_test_user.email, "password": "TestPassword123!"} + ) + token = login_response.json()["access_token"] + + response = await client.get( + "/api/v1/admin/users", + headers={"Authorization": f"Bearer {token}"} + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestAdminCreateUser: + """Tests for POST /admin/users endpoint.""" + + @pytest.mark.asyncio + async def test_admin_create_user_success(self, client, async_test_superuser, superuser_token): + """Test successfully creating a user as admin.""" + response = await client.post( + "/api/v1/admin/users", + json={ + "email": "newadminuser@example.com", + "password": "SecurePassword123!", + "first_name": "New", + "last_name": "User" + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["email"] == "newadminuser@example.com" + + @pytest.mark.asyncio + async def test_admin_create_user_duplicate_email(self, client, async_test_superuser, async_test_user, superuser_token): + """Test creating user with duplicate email fails.""" + response = await client.post( + "/api/v1/admin/users", + json={ + "email": async_test_user.email, + "password": "SecurePassword123!", + "first_name": "Duplicate", + "last_name": "User" + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminGetUser: + """Tests for GET /admin/users/{user_id} endpoint.""" + + @pytest.mark.asyncio + async def test_admin_get_user_success(self, client, async_test_superuser, async_test_user, superuser_token): + """Test successfully getting user details.""" + response = await client.get( + f"/api/v1/admin/users/{async_test_user.id}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["id"] == str(async_test_user.id) + assert data["email"] == async_test_user.email + + @pytest.mark.asyncio + async def test_admin_get_user_not_found(self, client, async_test_superuser, superuser_token): + """Test getting non-existent user.""" + response = await client.get( + f"/api/v1/admin/users/{uuid4()}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminUpdateUser: + """Tests for PUT /admin/users/{user_id} endpoint.""" + + @pytest.mark.asyncio + async def test_admin_update_user_success(self, client, async_test_superuser, async_test_user, superuser_token): + """Test successfully updating a user.""" + response = await client.put( + f"/api/v1/admin/users/{async_test_user.id}", + json={"first_name": "Updated"}, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["first_name"] == "Updated" + + @pytest.mark.asyncio + async def test_admin_update_user_not_found(self, client, async_test_superuser, superuser_token): + """Test updating non-existent user.""" + response = await client.put( + f"/api/v1/admin/users/{uuid4()}", + json={"first_name": "Updated"}, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminDeleteUser: + """Tests for DELETE /admin/users/{user_id} endpoint.""" + + @pytest.mark.asyncio + async def test_admin_delete_user_success(self, client, async_test_superuser, async_test_db, superuser_token): + """Test successfully deleting a user.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create user to delete + async with AsyncTestingSessionLocal() as session: + from app.models.user import User + from app.core.auth import get_password_hash + user_to_delete = User( + email="todelete@example.com", + password_hash=get_password_hash("TestPassword123!"), + first_name="To", + last_name="Delete" + ) + session.add(user_to_delete) + await session.commit() + user_id = user_to_delete.id + + response = await client.delete( + f"/api/v1/admin/users/{user_id}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + + @pytest.mark.asyncio + async def test_admin_delete_user_not_found(self, client, async_test_superuser, superuser_token): + """Test deleting non-existent user.""" + response = await client.delete( + f"/api/v1/admin/users/{uuid4()}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + @pytest.mark.asyncio + async def test_admin_delete_self_forbidden(self, client, async_test_superuser, superuser_token): + """Test admin cannot delete their own account.""" + response = await client.delete( + f"/api/v1/admin/users/{async_test_superuser.id}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestAdminActivateUser: + """Tests for POST /admin/users/{user_id}/activate endpoint.""" + + @pytest.mark.asyncio + async def test_admin_activate_user_success(self, client, async_test_superuser, async_test_db, superuser_token): + """Test successfully activating a user.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create inactive user + async with AsyncTestingSessionLocal() as session: + from app.models.user import User + from app.core.auth import get_password_hash + inactive_user = User( + email="toactivate@example.com", + password_hash=get_password_hash("TestPassword123!"), + first_name="To", + last_name="Activate", + is_active=False + ) + session.add(inactive_user) + await session.commit() + user_id = inactive_user.id + + response = await client.post( + f"/api/v1/admin/users/{user_id}/activate", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + + @pytest.mark.asyncio + async def test_admin_activate_user_not_found(self, client, async_test_superuser, superuser_token): + """Test activating non-existent user.""" + response = await client.post( + f"/api/v1/admin/users/{uuid4()}/activate", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminDeactivateUser: + """Tests for POST /admin/users/{user_id}/deactivate endpoint.""" + + @pytest.mark.asyncio + async def test_admin_deactivate_user_success(self, client, async_test_superuser, async_test_user, superuser_token): + """Test successfully deactivating a user.""" + response = await client.post( + f"/api/v1/admin/users/{async_test_user.id}/deactivate", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + + @pytest.mark.asyncio + async def test_admin_deactivate_user_not_found(self, client, async_test_superuser, superuser_token): + """Test deactivating non-existent user.""" + response = await client.post( + f"/api/v1/admin/users/{uuid4()}/deactivate", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + @pytest.mark.asyncio + async def test_admin_deactivate_self_forbidden(self, client, async_test_superuser, superuser_token): + """Test admin cannot deactivate their own account.""" + response = await client.post( + f"/api/v1/admin/users/{async_test_superuser.id}/deactivate", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestAdminBulkUserAction: + """Tests for POST /admin/users/bulk-action endpoint.""" + + @pytest.mark.asyncio + async def test_admin_bulk_activate_users(self, client, async_test_superuser, async_test_db, superuser_token): + """Test bulk activating users.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create inactive users + user_ids = [] + async with AsyncTestingSessionLocal() as session: + from app.models.user import User + from app.core.auth import get_password_hash + for i in range(3): + user = User( + email=f"bulk{i}@example.com", + password_hash=get_password_hash("TestPassword123!"), + first_name=f"Bulk{i}", + last_name="User", + is_active=False + ) + session.add(user) + await session.flush() + user_ids.append(str(user.id)) + await session.commit() + + response = await client.post( + "/api/v1/admin/users/bulk-action", + json={ + "action": "activate", + "user_ids": user_ids + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["affected_count"] == 3 + + @pytest.mark.asyncio + async def test_admin_bulk_deactivate_users(self, client, async_test_superuser, async_test_db, superuser_token): + """Test bulk deactivating users.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create active users + user_ids = [] + async with AsyncTestingSessionLocal() as session: + from app.models.user import User + from app.core.auth import get_password_hash + for i in range(2): + user = User( + email=f"deactivate{i}@example.com", + password_hash=get_password_hash("TestPassword123!"), + first_name=f"Deactivate{i}", + last_name="User", + is_active=True + ) + session.add(user) + await session.flush() + user_ids.append(str(user.id)) + await session.commit() + + response = await client.post( + "/api/v1/admin/users/bulk-action", + json={ + "action": "deactivate", + "user_ids": user_ids + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["affected_count"] == 2 + + @pytest.mark.asyncio + async def test_admin_bulk_delete_users(self, client, async_test_superuser, async_test_db, superuser_token): + """Test bulk deleting users.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create users to delete + user_ids = [] + async with AsyncTestingSessionLocal() as session: + from app.models.user import User + from app.core.auth import get_password_hash + for i in range(2): + user = User( + email=f"bulkdelete{i}@example.com", + password_hash=get_password_hash("TestPassword123!"), + first_name=f"BulkDelete{i}", + last_name="User" + ) + session.add(user) + await session.flush() + user_ids.append(str(user.id)) + await session.commit() + + response = await client.post( + "/api/v1/admin/users/bulk-action", + json={ + "action": "delete", + "user_ids": user_ids + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["affected_count"] >= 0 + + +# ===== ORGANIZATION MANAGEMENT TESTS ===== + +class TestAdminListOrganizations: + """Tests for GET /admin/organizations endpoint.""" + + @pytest.mark.asyncio + async def test_admin_list_organizations_success(self, client, async_test_superuser, async_test_db, superuser_token): + """Test successfully listing organizations.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organization + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + response = await client.get( + "/api/v1/admin/organizations", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "data" in data + assert "pagination" in data + + @pytest.mark.asyncio + async def test_admin_list_organizations_with_search(self, client, async_test_superuser, async_test_db, superuser_token): + """Test searching organizations.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organization + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Searchable Org", slug="searchable-org") + session.add(org) + await session.commit() + + response = await client.get( + "/api/v1/admin/organizations?search=Searchable", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + + +class TestAdminCreateOrganization: + """Tests for POST /admin/organizations endpoint.""" + + @pytest.mark.asyncio + async def test_admin_create_organization_success(self, client, async_test_superuser, superuser_token): + """Test successfully creating an organization.""" + response = await client.post( + "/api/v1/admin/organizations", + json={ + "name": "New Admin Org", + "slug": "new-admin-org", + "description": "Created by admin" + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["name"] == "New Admin Org" + assert data["member_count"] == 0 + + @pytest.mark.asyncio + async def test_admin_create_organization_duplicate_slug(self, client, async_test_superuser, async_test_db, superuser_token): + """Test creating organization with duplicate slug fails.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create existing organization + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Existing", slug="duplicate-slug") + session.add(org) + await session.commit() + + response = await client.post( + "/api/v1/admin/organizations", + json={ + "name": "Duplicate", + "slug": "duplicate-slug" + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminGetOrganization: + """Tests for GET /admin/organizations/{org_id} endpoint.""" + + @pytest.mark.asyncio + async def test_admin_get_organization_success(self, client, async_test_superuser, async_test_db, superuser_token): + """Test successfully getting organization details.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organization + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Get Test Org", slug="get-test-org") + session.add(org) + await session.commit() + org_id = org.id + + response = await client.get( + f"/api/v1/admin/organizations/{org_id}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Get Test Org" + + @pytest.mark.asyncio + async def test_admin_get_organization_not_found(self, client, async_test_superuser, superuser_token): + """Test getting non-existent organization.""" + response = await client.get( + f"/api/v1/admin/organizations/{uuid4()}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminUpdateOrganization: + """Tests for PUT /admin/organizations/{org_id} endpoint.""" + + @pytest.mark.asyncio + async def test_admin_update_organization_success(self, client, async_test_superuser, async_test_db, superuser_token): + """Test successfully updating an organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organization + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Update Test", slug="update-test") + session.add(org) + await session.commit() + org_id = org.id + + response = await client.put( + f"/api/v1/admin/organizations/{org_id}", + json={"name": "Updated Name"}, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["name"] == "Updated Name" + + @pytest.mark.asyncio + async def test_admin_update_organization_not_found(self, client, async_test_superuser, superuser_token): + """Test updating non-existent organization.""" + response = await client.put( + f"/api/v1/admin/organizations/{uuid4()}", + json={"name": "Updated"}, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminDeleteOrganization: + """Tests for DELETE /admin/organizations/{org_id} endpoint.""" + + @pytest.mark.asyncio + async def test_admin_delete_organization_success(self, client, async_test_superuser, async_test_db, superuser_token): + """Test successfully deleting an organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organization + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Delete Test", slug="delete-test") + session.add(org) + await session.commit() + org_id = org.id + + response = await client.delete( + f"/api/v1/admin/organizations/{org_id}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + + @pytest.mark.asyncio + async def test_admin_delete_organization_not_found(self, client, async_test_superuser, superuser_token): + """Test deleting non-existent organization.""" + response = await client.delete( + f"/api/v1/admin/organizations/{uuid4()}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminListOrganizationMembers: + """Tests for GET /admin/organizations/{org_id}/members endpoint.""" + + @pytest.mark.asyncio + async def test_admin_list_organization_members_success(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + """Test successfully listing organization members.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organization with member + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Members Test", slug="members-test") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + response = await client.get( + f"/api/v1/admin/organizations/{org_id}/members", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "data" in data + assert len(data["data"]) >= 1 + + @pytest.mark.asyncio + async def test_admin_list_organization_members_not_found(self, client, async_test_superuser, superuser_token): + """Test listing members of non-existent organization.""" + response = await client.get( + f"/api/v1/admin/organizations/{uuid4()}/members", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminAddOrganizationMember: + """Tests for POST /admin/organizations/{org_id}/members endpoint.""" + + @pytest.mark.asyncio + async def test_admin_add_organization_member_success(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + """Test successfully adding a member to organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organization + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Add Member Test", slug="add-member-test") + session.add(org) + await session.commit() + org_id = org.id + + response = await client.post( + f"/api/v1/admin/organizations/{org_id}/members", + json={ + "user_id": str(async_test_user.id), + "role": "member" + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + + @pytest.mark.asyncio + async def test_admin_add_organization_member_already_exists(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + """Test adding member who is already a member.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create organization with existing member + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Existing Member", slug="existing-member") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + response = await client.post( + f"/api/v1/admin/organizations/{org_id}/members", + json={ + "user_id": str(async_test_user.id), + "role": "member" + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_409_CONFLICT + + @pytest.mark.asyncio + async def test_admin_add_organization_member_org_not_found(self, client, async_test_superuser, async_test_user, superuser_token): + """Test adding member to non-existent organization.""" + response = await client.post( + f"/api/v1/admin/organizations/{uuid4()}/members", + json={ + "user_id": str(async_test_user.id), + "role": "member" + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + @pytest.mark.asyncio + async def test_admin_add_organization_member_user_not_found(self, client, async_test_superuser, async_test_db, superuser_token): + """Test adding non-existent user to organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organization + async with AsyncTestingSessionLocal() as session: + org = Organization(name="User Not Found", slug="user-not-found") + session.add(org) + await session.commit() + org_id = org.id + + response = await client.post( + f"/api/v1/admin/organizations/{org_id}/members", + json={ + "user_id": str(uuid4()), + "role": "member" + }, + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestAdminRemoveOrganizationMember: + """Tests for DELETE /admin/organizations/{org_id}/members/{user_id} endpoint.""" + + @pytest.mark.asyncio + async def test_admin_remove_organization_member_success(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + """Test successfully removing a member from organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create organization with member + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Remove Member", slug="remove-member") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + response = await client.delete( + f"/api/v1/admin/organizations/{org_id}/members/{async_test_user.id}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["success"] is True + + @pytest.mark.asyncio + async def test_admin_remove_organization_member_not_member(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + """Test removing user who is not a member.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create organization without member + async with AsyncTestingSessionLocal() as session: + org = Organization(name="No Member", slug="no-member") + session.add(org) + await session.commit() + org_id = org.id + + response = await client.delete( + f"/api/v1/admin/organizations/{org_id}/members/{async_test_user.id}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + @pytest.mark.asyncio + async def test_admin_remove_organization_member_org_not_found(self, client, async_test_superuser, async_test_user, superuser_token): + """Test removing member from non-existent organization.""" + response = await client.delete( + f"/api/v1/admin/organizations/{uuid4()}/members/{async_test_user.id}", + headers={"Authorization": f"Bearer {superuser_token}"} + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/backend/tests/api/test_auth_endpoints.py b/backend/tests/api/test_auth_endpoints.py index b753b85..833ff06 100755 --- a/backend/tests/api/test_auth_endpoints.py +++ b/backend/tests/api/test_auth_endpoints.py @@ -326,59 +326,3 @@ class TestRefreshTokenEndpoint: ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - - -class TestGetCurrentUserEndpoint: - """Tests for GET /auth/me endpoint.""" - - @pytest.mark.asyncio - async def test_get_current_user_success(self, client, async_test_user): - """Test getting current user info.""" - # First, login to get an access token - login_response = await client.post( - "/api/v1/auth/login", - json={ - "email": async_test_user.email, - "password": "TestPassword123!" - } - ) - access_token = login_response.json()["access_token"] - - # Get current user info - response = await client.get( - "/api/v1/auth/me", - headers={"Authorization": f"Bearer {access_token}"} - ) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["email"] == async_test_user.email - assert data["first_name"] == async_test_user.first_name - - @pytest.mark.asyncio - async def test_get_current_user_no_token(self, client): - """Test getting current user without token.""" - response = await client.get("/api/v1/auth/me") - - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - @pytest.mark.asyncio - async def test_get_current_user_invalid_token(self, client): - """Test getting current user with invalid token.""" - response = await client.get( - "/api/v1/auth/me", - headers={"Authorization": "Bearer invalid_token"} - ) - - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - @pytest.mark.asyncio - async def test_get_current_user_expired_token(self, client): - """Test getting current user with expired token.""" - # Use a clearly invalid/malformed token - response = await client.get( - "/api/v1/auth/me", - headers={"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.invalid"} - ) - - assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/backend/tests/crud/test_organization_async.py b/backend/tests/crud/test_organization_async.py new file mode 100644 index 0000000..f5586fe --- /dev/null +++ b/backend/tests/crud/test_organization_async.py @@ -0,0 +1,944 @@ +# tests/crud/test_organization_async.py +""" +Comprehensive tests for async organization CRUD operations. +""" +import pytest +from uuid import uuid4 +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.crud.organization_async import organization_async +from app.models.organization import Organization +from app.models.user_organization import UserOrganization, OrganizationRole +from app.models.user import User +from app.schemas.organizations import OrganizationCreate, OrganizationUpdate + + +class TestGetBySlug: + """Tests for get_by_slug method.""" + + @pytest.mark.asyncio + async def test_get_by_slug_success(self, async_test_db): + """Test successfully getting an organization by slug.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organization + async with AsyncTestingSessionLocal() as session: + org = Organization( + name="Test Org", + slug="test-org", + description="Test description" + ) + session.add(org) + await session.commit() + org_id = org.id + + # Get by slug + async with AsyncTestingSessionLocal() as session: + result = await organization_async.get_by_slug(session, slug="test-org") + assert result is not None + assert result.id == org_id + assert result.slug == "test-org" + + @pytest.mark.asyncio + async def test_get_by_slug_not_found(self, async_test_db): + """Test getting non-existent organization by slug.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + result = await organization_async.get_by_slug(session, slug="nonexistent") + assert result is None + + +class TestCreate: + """Tests for create method.""" + + @pytest.mark.asyncio + async def test_create_success(self, async_test_db): + """Test successfully creating an organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org_in = OrganizationCreate( + name="New Org", + slug="new-org", + description="New organization", + is_active=True, + settings={"key": "value"} + ) + result = await organization_async.create(session, obj_in=org_in) + + assert result.name == "New Org" + assert result.slug == "new-org" + assert result.description == "New organization" + assert result.is_active is True + assert result.settings == {"key": "value"} + + @pytest.mark.asyncio + async def test_create_duplicate_slug(self, async_test_db): + """Test creating organization with duplicate slug raises error.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create first org + async with AsyncTestingSessionLocal() as session: + org1 = Organization(name="Org 1", slug="duplicate-slug") + session.add(org1) + await session.commit() + + # Try to create second with same slug + async with AsyncTestingSessionLocal() as session: + org_in = OrganizationCreate( + name="Org 2", + slug="duplicate-slug" + ) + with pytest.raises(ValueError, match="already exists"): + await organization_async.create(session, obj_in=org_in) + + @pytest.mark.asyncio + async def test_create_without_settings(self, async_test_db): + """Test creating organization without settings (defaults to empty dict).""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org_in = OrganizationCreate( + name="No Settings Org", + slug="no-settings" + ) + result = await organization_async.create(session, obj_in=org_in) + + assert result.settings == {} + + +class TestGetMultiWithFilters: + """Tests for get_multi_with_filters method.""" + + @pytest.mark.asyncio + async def test_get_multi_with_filters_no_filters(self, async_test_db): + """Test getting organizations without any filters.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create test organizations + async with AsyncTestingSessionLocal() as session: + for i in range(5): + org = Organization(name=f"Org {i}", slug=f"org-{i}") + session.add(org) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs, total = await organization_async.get_multi_with_filters(session) + assert total == 5 + assert len(orgs) == 5 + + @pytest.mark.asyncio + async def test_get_multi_with_filters_is_active(self, async_test_db): + """Test filtering by is_active.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org1 = Organization(name="Active", slug="active", is_active=True) + org2 = Organization(name="Inactive", slug="inactive", is_active=False) + session.add_all([org1, org2]) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs, total = await organization_async.get_multi_with_filters( + session, + is_active=True + ) + assert total == 1 + assert orgs[0].name == "Active" + + @pytest.mark.asyncio + async def test_get_multi_with_filters_search(self, async_test_db): + """Test searching organizations.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org1 = Organization(name="Tech Corp", slug="tech-corp", description="Technology") + org2 = Organization(name="Food Inc", slug="food-inc", description="Restaurant") + session.add_all([org1, org2]) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs, total = await organization_async.get_multi_with_filters( + session, + search="tech" + ) + assert total == 1 + assert orgs[0].name == "Tech Corp" + + @pytest.mark.asyncio + async def test_get_multi_with_filters_pagination(self, async_test_db): + """Test pagination.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + for i in range(10): + org = Organization(name=f"Org {i}", slug=f"org-{i}") + session.add(org) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs, total = await organization_async.get_multi_with_filters( + session, + skip=2, + limit=3 + ) + assert total == 10 + assert len(orgs) == 3 + + @pytest.mark.asyncio + async def test_get_multi_with_filters_sorting(self, async_test_db): + """Test sorting.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org1 = Organization(name="B Org", slug="b-org") + org2 = Organization(name="A Org", slug="a-org") + session.add_all([org1, org2]) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs, total = await organization_async.get_multi_with_filters( + session, + sort_by="name", + sort_order="asc" + ) + assert orgs[0].name == "A Org" + assert orgs[1].name == "B Org" + + +class TestGetMemberCount: + """Tests for get_member_count method.""" + + @pytest.mark.asyncio + async def test_get_member_count_success(self, async_test_db, async_test_user): + """Test getting member count for organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + # Add 1 active member + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + count = await organization_async.get_member_count(session, organization_id=org_id) + assert count == 1 + + @pytest.mark.asyncio + async def test_get_member_count_no_members(self, async_test_db): + """Test getting member count for organization with no members.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Empty Org", slug="empty-org") + session.add(org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + count = await organization_async.get_member_count(session, organization_id=org_id) + assert count == 0 + + +class TestAddUser: + """Tests for add_user method.""" + + @pytest.mark.asyncio + async def test_add_user_success(self, async_test_db, async_test_user): + """Test successfully adding a user to organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + result = await organization_async.add_user( + session, + organization_id=org_id, + user_id=async_test_user.id, + role=OrganizationRole.ADMIN + ) + + assert result.user_id == async_test_user.id + assert result.organization_id == org_id + assert result.role == OrganizationRole.ADMIN + assert result.is_active is True + + @pytest.mark.asyncio + async def test_add_user_already_active_member(self, async_test_db, async_test_user): + """Test adding user who is already an active member raises error.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + with pytest.raises(ValueError, match="already a member"): + await organization_async.add_user( + session, + organization_id=org_id, + user_id=async_test_user.id + ) + + @pytest.mark.asyncio + async def test_add_user_reactivate_inactive(self, async_test_db, async_test_user): + """Test adding user who was previously inactive reactivates them.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=False + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + result = await organization_async.add_user( + session, + organization_id=org_id, + user_id=async_test_user.id, + role=OrganizationRole.ADMIN + ) + + assert result.is_active is True + assert result.role == OrganizationRole.ADMIN + + +class TestRemoveUser: + """Tests for remove_user method.""" + + @pytest.mark.asyncio + async def test_remove_user_success(self, async_test_db, async_test_user): + """Test successfully removing a user from organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + result = await organization_async.remove_user( + session, + organization_id=org_id, + user_id=async_test_user.id + ) + + assert result is True + + # Verify soft delete + async with AsyncTestingSessionLocal() as session: + stmt = select(UserOrganization).where( + UserOrganization.user_id == async_test_user.id, + UserOrganization.organization_id == org_id + ) + result = await session.execute(stmt) + user_org = result.scalar_one_or_none() + assert user_org.is_active is False + + @pytest.mark.asyncio + async def test_remove_user_not_found(self, async_test_db): + """Test removing non-existent user returns False.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + result = await organization_async.remove_user( + session, + organization_id=org_id, + user_id=uuid4() + ) + + assert result is False + + +class TestUpdateUserRole: + """Tests for update_user_role method.""" + + @pytest.mark.asyncio + async def test_update_user_role_success(self, async_test_db, async_test_user): + """Test successfully updating user role.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + result = await organization_async.update_user_role( + session, + organization_id=org_id, + user_id=async_test_user.id, + role=OrganizationRole.ADMIN, + custom_permissions="custom" + ) + + assert result.role == OrganizationRole.ADMIN + assert result.custom_permissions == "custom" + + @pytest.mark.asyncio + async def test_update_user_role_not_found(self, async_test_db): + """Test updating role for non-existent user returns None.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + result = await organization_async.update_user_role( + session, + organization_id=org_id, + user_id=uuid4(), + role=OrganizationRole.ADMIN + ) + + assert result is None + + +class TestGetOrganizationMembers: + """Tests for get_organization_members method.""" + + @pytest.mark.asyncio + async def test_get_organization_members_success(self, async_test_db, async_test_user): + """Test getting organization members.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.ADMIN, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + members, total = await organization_async.get_organization_members( + session, + organization_id=org_id + ) + + assert total == 1 + assert len(members) == 1 + assert members[0]["user_id"] == async_test_user.id + assert members[0]["email"] == async_test_user.email + assert members[0]["role"] == OrganizationRole.ADMIN + + @pytest.mark.asyncio + async def test_get_organization_members_with_pagination(self, async_test_db, async_test_user): + """Test getting organization members with pagination.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + members, total = await organization_async.get_organization_members( + session, + organization_id=org_id, + skip=0, + limit=10 + ) + + assert total == 1 + assert len(members) <= 10 + + +class TestGetUserOrganizations: + """Tests for get_user_organizations method.""" + + @pytest.mark.asyncio + async def test_get_user_organizations_success(self, async_test_db, async_test_user): + """Test getting user's organizations.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs = await organization_async.get_user_organizations( + session, + user_id=async_test_user.id + ) + + assert len(orgs) == 1 + assert orgs[0].name == "Test Org" + + @pytest.mark.asyncio + async def test_get_user_organizations_filter_inactive(self, async_test_db, async_test_user): + """Test filtering inactive organizations.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org1 = Organization(name="Active Org", slug="active-org") + org2 = Organization(name="Inactive Org", slug="inactive-org") + session.add_all([org1, org2]) + await session.commit() + + user_org1 = UserOrganization( + user_id=async_test_user.id, + organization_id=org1.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + user_org2 = UserOrganization( + user_id=async_test_user.id, + organization_id=org2.id, + role=OrganizationRole.MEMBER, + is_active=False + ) + session.add_all([user_org1, user_org2]) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs = await organization_async.get_user_organizations( + session, + user_id=async_test_user.id, + is_active=True + ) + + assert len(orgs) == 1 + assert orgs[0].name == "Active Org" + + +class TestGetUserRole: + """Tests for get_user_role_in_org method.""" + + @pytest.mark.asyncio + async def test_get_user_role_in_org_success(self, async_test_db, async_test_user): + """Test getting user role in organization.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.ADMIN, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + role = await organization_async.get_user_role_in_org( + session, + user_id=async_test_user.id, + organization_id=org_id + ) + + assert role == OrganizationRole.ADMIN + + @pytest.mark.asyncio + async def test_get_user_role_in_org_not_found(self, async_test_db): + """Test getting role for non-member returns None.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + role = await organization_async.get_user_role_in_org( + session, + user_id=uuid4(), + organization_id=org_id + ) + + assert role is None + + +class TestIsUserOrgOwner: + """Tests for is_user_org_owner method.""" + + @pytest.mark.asyncio + async def test_is_user_org_owner_true(self, async_test_db, async_test_user): + """Test checking if user is owner.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.OWNER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + is_owner = await organization_async.is_user_org_owner( + session, + user_id=async_test_user.id, + organization_id=org_id + ) + + assert is_owner is True + + @pytest.mark.asyncio + async def test_is_user_org_owner_false(self, async_test_db, async_test_user): + """Test checking if non-owner user is owner.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + is_owner = await organization_async.is_user_org_owner( + session, + user_id=async_test_user.id, + organization_id=org_id + ) + + assert is_owner is False + + +class TestGetMultiWithMemberCounts: + """Tests for get_multi_with_member_counts method.""" + + @pytest.mark.asyncio + async def test_get_multi_with_member_counts_success(self, async_test_db, async_test_user): + """Test getting organizations with member counts.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org1 = Organization(name="Org 1", slug="org-1") + org2 = Organization(name="Org 2", slug="org-2") + session.add_all([org1, org2]) + await session.commit() + + # Add members to org1 + user_org1 = UserOrganization( + user_id=async_test_user.id, + organization_id=org1.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org1) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs_with_counts, total = await organization_async.get_multi_with_member_counts(session) + + assert total == 2 + assert len(orgs_with_counts) == 2 + # Verify structure + assert 'organization' in orgs_with_counts[0] + assert 'member_count' in orgs_with_counts[0] + + @pytest.mark.asyncio + async def test_get_multi_with_member_counts_with_filters(self, async_test_db): + """Test getting organizations with member counts and filters.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org1 = Organization(name="Active Org", slug="active-org", is_active=True) + org2 = Organization(name="Inactive Org", slug="inactive-org", is_active=False) + session.add_all([org1, org2]) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs_with_counts, total = await organization_async.get_multi_with_member_counts( + session, + is_active=True + ) + + assert total == 1 + assert orgs_with_counts[0]['organization'].name == "Active Org" + + @pytest.mark.asyncio + async def test_get_multi_with_member_counts_with_search(self, async_test_db): + """Test searching organizations with member counts.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org1 = Organization(name="Tech Corp", slug="tech-corp") + org2 = Organization(name="Food Inc", slug="food-inc") + session.add_all([org1, org2]) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs_with_counts, total = await organization_async.get_multi_with_member_counts( + session, + search="tech" + ) + + assert total == 1 + assert orgs_with_counts[0]['organization'].name == "Tech Corp" + + +class TestGetUserOrganizationsWithDetails: + """Tests for get_user_organizations_with_details method.""" + + @pytest.mark.asyncio + async def test_get_user_organizations_with_details_success(self, async_test_db, async_test_user): + """Test getting user organizations with role and member count.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.ADMIN, + is_active=True + ) + session.add(user_org) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs_with_details = await organization_async.get_user_organizations_with_details( + session, + user_id=async_test_user.id + ) + + assert len(orgs_with_details) == 1 + assert orgs_with_details[0]['organization'].name == "Test Org" + assert orgs_with_details[0]['role'] == OrganizationRole.ADMIN + assert 'member_count' in orgs_with_details[0] + + @pytest.mark.asyncio + async def test_get_user_organizations_with_details_filter_inactive(self, async_test_db, async_test_user): + """Test filtering inactive organizations in user details.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org1 = Organization(name="Active Org", slug="active-org") + org2 = Organization(name="Inactive Org", slug="inactive-org") + session.add_all([org1, org2]) + await session.commit() + + user_org1 = UserOrganization( + user_id=async_test_user.id, + organization_id=org1.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + user_org2 = UserOrganization( + user_id=async_test_user.id, + organization_id=org2.id, + role=OrganizationRole.MEMBER, + is_active=False + ) + session.add_all([user_org1, user_org2]) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + orgs_with_details = await organization_async.get_user_organizations_with_details( + session, + user_id=async_test_user.id, + is_active=True + ) + + assert len(orgs_with_details) == 1 + assert orgs_with_details[0]['organization'].name == "Active Org" + + +class TestIsUserOrgAdmin: + """Tests for is_user_org_admin method.""" + + @pytest.mark.asyncio + async def test_is_user_org_admin_owner(self, async_test_db, async_test_user): + """Test checking if owner is admin (should be True).""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.OWNER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + is_admin = await organization_async.is_user_org_admin( + session, + user_id=async_test_user.id, + organization_id=org_id + ) + + assert is_admin is True + + @pytest.mark.asyncio + async def test_is_user_org_admin_admin_role(self, async_test_db, async_test_user): + """Test checking if admin role is admin.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.ADMIN, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + is_admin = await organization_async.is_user_org_admin( + session, + user_id=async_test_user.id, + organization_id=org_id + ) + + assert is_admin is True + + @pytest.mark.asyncio + async def test_is_user_org_admin_member_false(self, async_test_db, async_test_user): + """Test checking if regular member is admin.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + org = Organization(name="Test Org", slug="test-org") + session.add(org) + await session.commit() + + user_org = UserOrganization( + user_id=async_test_user.id, + organization_id=org.id, + role=OrganizationRole.MEMBER, + is_active=True + ) + session.add(user_org) + await session.commit() + org_id = org.id + + async with AsyncTestingSessionLocal() as session: + is_admin = await organization_async.is_user_org_admin( + session, + user_id=async_test_user.id, + organization_id=org_id + ) + + assert is_admin is False diff --git a/backend/tests/crud/test_session_async.py b/backend/tests/crud/test_session_async.py new file mode 100644 index 0000000..5759379 --- /dev/null +++ b/backend/tests/crud/test_session_async.py @@ -0,0 +1,339 @@ +# tests/crud/test_session_async.py +""" +Comprehensive tests for async session CRUD operations. +""" +import pytest +from datetime import datetime, timedelta, timezone +from uuid import uuid4 + +from app.crud.session_async import session_async +from app.models.user_session import UserSession +from app.schemas.sessions import SessionCreate + + +class TestGetByJti: + """Tests for get_by_jti method.""" + + @pytest.mark.asyncio + async def test_get_by_jti_success(self, async_test_db, async_test_user): + """Test getting session by JTI.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user_session = UserSession( + user_id=async_test_user.id, + refresh_token_jti="test_jti_123", + device_name="Test Device", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) + ) + session.add(user_session) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + result = await session_async.get_by_jti(session, jti="test_jti_123") + assert result is not None + assert result.refresh_token_jti == "test_jti_123" + + @pytest.mark.asyncio + async def test_get_by_jti_not_found(self, async_test_db): + """Test getting non-existent JTI returns None.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + result = await session_async.get_by_jti(session, jti="nonexistent") + assert result is None + + +class TestGetActiveByJti: + """Tests for get_active_by_jti method.""" + + @pytest.mark.asyncio + async def test_get_active_by_jti_success(self, async_test_db, async_test_user): + """Test getting active session by JTI.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user_session = UserSession( + user_id=async_test_user.id, + refresh_token_jti="active_jti", + device_name="Test Device", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) + ) + session.add(user_session) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + result = await session_async.get_active_by_jti(session, jti="active_jti") + assert result is not None + assert result.is_active is True + + @pytest.mark.asyncio + async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user): + """Test getting inactive session by JTI returns None.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user_session = UserSession( + user_id=async_test_user.id, + refresh_token_jti="inactive_jti", + device_name="Test Device", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=False, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) + ) + session.add(user_session) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + result = await session_async.get_active_by_jti(session, jti="inactive_jti") + assert result is None + + +class TestGetUserSessions: + """Tests for get_user_sessions method.""" + + @pytest.mark.asyncio + async def test_get_user_sessions_active_only(self, async_test_db, async_test_user): + """Test getting only active user sessions.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + active = UserSession( + user_id=async_test_user.id, + refresh_token_jti="active", + device_name="Active Device", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) + ) + inactive = UserSession( + user_id=async_test_user.id, + refresh_token_jti="inactive", + device_name="Inactive Device", + ip_address="192.168.1.2", + user_agent="Mozilla/5.0", + is_active=False, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) + ) + session.add_all([active, inactive]) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + results = await session_async.get_user_sessions( + session, + user_id=str(async_test_user.id), + active_only=True + ) + assert len(results) == 1 + assert results[0].is_active is True + + @pytest.mark.asyncio + async def test_get_user_sessions_all(self, async_test_db, async_test_user): + """Test getting all user sessions.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + for i in range(3): + sess = UserSession( + user_id=async_test_user.id, + refresh_token_jti=f"session_{i}", + device_name=f"Device {i}", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=i % 2 == 0, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) + ) + session.add(sess) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + results = await session_async.get_user_sessions( + session, + user_id=str(async_test_user.id), + active_only=False + ) + assert len(results) == 3 + + +class TestCreateSession: + """Tests for create_session method.""" + + @pytest.mark.asyncio + async def test_create_session_success(self, async_test_db, async_test_user): + """Test successfully creating a session.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + session_data = SessionCreate( + user_id=async_test_user.id, + refresh_token_jti="new_jti", + device_name="New Device", + device_id="device_123", + ip_address="192.168.1.100", + user_agent="Mozilla/5.0", + last_used_at=datetime.now(timezone.utc), + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + location_city="San Francisco", + location_country="USA" + ) + result = await session_async.create_session(session, obj_in=session_data) + + assert result.user_id == async_test_user.id + assert result.refresh_token_jti == "new_jti" + assert result.is_active is True + assert result.location_city == "San Francisco" + + +class TestDeactivate: + """Tests for deactivate method.""" + + @pytest.mark.asyncio + async def test_deactivate_success(self, async_test_db, async_test_user): + """Test successfully deactivating a session.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user_session = UserSession( + user_id=async_test_user.id, + refresh_token_jti="to_deactivate", + device_name="Test Device", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) + ) + session.add(user_session) + await session.commit() + session_id = user_session.id + + async with AsyncTestingSessionLocal() as session: + result = await session_async.deactivate(session, session_id=str(session_id)) + assert result is not None + assert result.is_active is False + + @pytest.mark.asyncio + async def test_deactivate_not_found(self, async_test_db): + """Test deactivating non-existent session returns None.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + result = await session_async.deactivate(session, session_id=str(uuid4())) + assert result is None + + +class TestDeactivateAllUserSessions: + """Tests for deactivate_all_user_sessions method.""" + + @pytest.mark.asyncio + async def test_deactivate_all_user_sessions_success(self, async_test_db, async_test_user): + """Test deactivating all user sessions.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + for i in range(5): + sess = UserSession( + user_id=async_test_user.id, + refresh_token_jti=f"bulk_{i}", + device_name=f"Device {i}", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) + ) + session.add(sess) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + count = await session_async.deactivate_all_user_sessions( + session, + user_id=str(async_test_user.id) + ) + assert count == 5 + + +class TestUpdateLastUsed: + """Tests for update_last_used method.""" + + @pytest.mark.asyncio + async def test_update_last_used_success(self, async_test_db, async_test_user): + """Test updating last_used_at timestamp.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user_session = UserSession( + user_id=async_test_user.id, + refresh_token_jti="update_test", + device_name="Test Device", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) - timedelta(hours=1) + ) + session.add(user_session) + await session.commit() + await session.refresh(user_session) + + old_time = user_session.last_used_at + result = await session_async.update_last_used(session, session=user_session) + + assert result.last_used_at > old_time + + +class TestGetUserSessionCount: + """Tests for get_user_session_count method.""" + + @pytest.mark.asyncio + async def test_get_user_session_count_success(self, async_test_db, async_test_user): + """Test getting user session count.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + for i in range(3): + sess = UserSession( + user_id=async_test_user.id, + refresh_token_jti=f"count_{i}", + device_name=f"Device {i}", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(timezone.utc) + ) + session.add(sess) + await session.commit() + + async with AsyncTestingSessionLocal() as session: + count = await session_async.get_user_session_count( + session, + user_id=str(async_test_user.id) + ) + assert count == 3 + + @pytest.mark.asyncio + async def test_get_user_session_count_empty(self, async_test_db): + """Test getting session count for user with no sessions.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + count = await session_async.get_user_session_count( + session, + user_id=str(uuid4()) + ) + assert count == 0 diff --git a/backend/tests/crud/test_user_async.py b/backend/tests/crud/test_user_async.py new file mode 100644 index 0000000..72b5d91 --- /dev/null +++ b/backend/tests/crud/test_user_async.py @@ -0,0 +1,644 @@ +# tests/crud/test_user_async.py +""" +Comprehensive tests for async user CRUD operations. +""" +import pytest +from datetime import datetime, timezone +from uuid import uuid4 + +from app.crud.user_async import user_async +from app.models.user import User +from app.schemas.users import UserCreate, UserUpdate + + +class TestGetByEmail: + """Tests for get_by_email method.""" + + @pytest.mark.asyncio + async def test_get_by_email_success(self, async_test_db, async_test_user): + """Test getting user by email.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + result = await user_async.get_by_email(session, email=async_test_user.email) + assert result is not None + assert result.email == async_test_user.email + assert result.id == async_test_user.id + + @pytest.mark.asyncio + async def test_get_by_email_not_found(self, async_test_db): + """Test getting non-existent email returns None.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + result = await user_async.get_by_email(session, email="nonexistent@example.com") + assert result is None + + +class TestCreate: + """Tests for create method.""" + + @pytest.mark.asyncio + async def test_create_user_success(self, async_test_db): + """Test successfully creating a user.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user_data = UserCreate( + email="newuser@example.com", + password="SecurePass123!", + first_name="New", + last_name="User", + phone_number="+1234567890" + ) + result = await user_async.create(session, obj_in=user_data) + + assert result.email == "newuser@example.com" + assert result.first_name == "New" + assert result.last_name == "User" + assert result.phone_number == "+1234567890" + assert result.is_active is True + assert result.is_superuser is False + assert result.password_hash is not None + assert result.password_hash != "SecurePass123!" # Password should be hashed + + @pytest.mark.asyncio + async def test_create_superuser_success(self, async_test_db): + """Test creating a superuser.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user_data = UserCreate( + email="superuser@example.com", + password="SuperPass123!", + first_name="Super", + last_name="User", + is_superuser=True + ) + result = await user_async.create(session, obj_in=user_data) + + assert result.is_superuser is True + assert result.email == "superuser@example.com" + + @pytest.mark.asyncio + async def test_create_duplicate_email_fails(self, async_test_db, async_test_user): + """Test creating user with duplicate email raises ValueError.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user_data = UserCreate( + email=async_test_user.email, # Duplicate email + password="AnotherPass123!", + first_name="Duplicate", + last_name="User" + ) + + with pytest.raises(ValueError) as exc_info: + await user_async.create(session, obj_in=user_data) + + assert "already exists" in str(exc_info.value).lower() + + +class TestUpdate: + """Tests for update method.""" + + @pytest.mark.asyncio + async def test_update_user_basic_fields(self, async_test_db, async_test_user): + """Test updating basic user fields.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + # Get fresh copy of user + user = await user_async.get(session, id=str(async_test_user.id)) + + update_data = UserUpdate( + first_name="Updated", + last_name="Name", + phone_number="+9876543210" + ) + result = await user_async.update(session, db_obj=user, obj_in=update_data) + + assert result.first_name == "Updated" + assert result.last_name == "Name" + assert result.phone_number == "+9876543210" + + @pytest.mark.asyncio + async def test_update_user_password(self, async_test_db): + """Test updating user password.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create a fresh user for this test + async with AsyncTestingSessionLocal() as session: + user_data = UserCreate( + email="passwordtest@example.com", + password="OldPassword123!", + first_name="Pass", + last_name="Test" + ) + user = await user_async.create(session, obj_in=user_data) + user_id = user.id + old_password_hash = user.password_hash + + # Update the password + async with AsyncTestingSessionLocal() as session: + user = await user_async.get(session, id=str(user_id)) + + update_data = UserUpdate(password="NewDifferentPassword123!") + result = await user_async.update(session, db_obj=user, obj_in=update_data) + + await session.refresh(result) + assert result.password_hash != old_password_hash + assert result.password_hash is not None + assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed + + @pytest.mark.asyncio + async def test_update_user_with_dict(self, async_test_db, async_test_user): + """Test updating user with dictionary.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user = await user_async.get(session, id=str(async_test_user.id)) + + update_dict = {"first_name": "DictUpdate"} + result = await user_async.update(session, db_obj=user, obj_in=update_dict) + + assert result.first_name == "DictUpdate" + + +class TestGetMultiWithTotal: + """Tests for get_multi_with_total method.""" + + @pytest.mark.asyncio + async def test_get_multi_with_total_basic(self, async_test_db, async_test_user): + """Test basic pagination.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + users, total = await user_async.get_multi_with_total( + session, + skip=0, + limit=10 + ) + assert total >= 1 + assert len(users) >= 1 + assert any(u.id == async_test_user.id for u in users) + + @pytest.mark.asyncio + async def test_get_multi_with_total_sorting_asc(self, async_test_db): + """Test sorting in ascending order.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create multiple users + async with AsyncTestingSessionLocal() as session: + for i in range(3): + user_data = UserCreate( + email=f"sort{i}@example.com", + password="SecurePass123!", + first_name=f"User{i}", + last_name="Test" + ) + await user_async.create(session, obj_in=user_data) + + async with AsyncTestingSessionLocal() as session: + users, total = await user_async.get_multi_with_total( + session, + skip=0, + limit=10, + sort_by="email", + sort_order="asc" + ) + + # Check if sorted (at least the test users) + test_users = [u for u in users if u.email.startswith("sort")] + if len(test_users) > 1: + assert test_users[0].email < test_users[1].email + + @pytest.mark.asyncio + async def test_get_multi_with_total_sorting_desc(self, async_test_db): + """Test sorting in descending order.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create multiple users + async with AsyncTestingSessionLocal() as session: + for i in range(3): + user_data = UserCreate( + email=f"desc{i}@example.com", + password="SecurePass123!", + first_name=f"User{i}", + last_name="Test" + ) + await user_async.create(session, obj_in=user_data) + + async with AsyncTestingSessionLocal() as session: + users, total = await user_async.get_multi_with_total( + session, + skip=0, + limit=10, + sort_by="email", + sort_order="desc" + ) + + # Check if sorted descending (at least the test users) + test_users = [u for u in users if u.email.startswith("desc")] + if len(test_users) > 1: + assert test_users[0].email > test_users[1].email + + @pytest.mark.asyncio + async def test_get_multi_with_total_filtering(self, async_test_db): + """Test filtering by field.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create active and inactive users + async with AsyncTestingSessionLocal() as session: + active_user = UserCreate( + email="active@example.com", + password="SecurePass123!", + first_name="Active", + last_name="User" + ) + await user_async.create(session, obj_in=active_user) + + inactive_user = UserCreate( + email="inactive@example.com", + password="SecurePass123!", + first_name="Inactive", + last_name="User" + ) + created_inactive = await user_async.create(session, obj_in=inactive_user) + + # Deactivate the user + await user_async.update( + session, + db_obj=created_inactive, + obj_in={"is_active": False} + ) + + async with AsyncTestingSessionLocal() as session: + users, total = await user_async.get_multi_with_total( + session, + skip=0, + limit=100, + filters={"is_active": True} + ) + + # All returned users should be active + assert all(u.is_active for u in users) + + @pytest.mark.asyncio + async def test_get_multi_with_total_search(self, async_test_db): + """Test search functionality.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create user with unique name + async with AsyncTestingSessionLocal() as session: + user_data = UserCreate( + email="searchable@example.com", + password="SecurePass123!", + first_name="Searchable", + last_name="UserName" + ) + await user_async.create(session, obj_in=user_data) + + async with AsyncTestingSessionLocal() as session: + users, total = await user_async.get_multi_with_total( + session, + skip=0, + limit=100, + search="Searchable" + ) + + assert total >= 1 + assert any(u.first_name == "Searchable" for u in users) + + @pytest.mark.asyncio + async def test_get_multi_with_total_pagination(self, async_test_db): + """Test pagination with skip and limit.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create multiple users + async with AsyncTestingSessionLocal() as session: + for i in range(5): + user_data = UserCreate( + email=f"page{i}@example.com", + password="SecurePass123!", + first_name=f"Page{i}", + last_name="User" + ) + await user_async.create(session, obj_in=user_data) + + async with AsyncTestingSessionLocal() as session: + # Get first page + users_page1, total = await user_async.get_multi_with_total( + session, + skip=0, + limit=2 + ) + + # Get second page + users_page2, total2 = await user_async.get_multi_with_total( + session, + skip=2, + limit=2 + ) + + # Total should be same + assert total == total2 + # Different users on different pages + assert users_page1[0].id != users_page2[0].id + + @pytest.mark.asyncio + async def test_get_multi_with_total_validation_negative_skip(self, async_test_db): + """Test validation fails for negative skip.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + with pytest.raises(ValueError) as exc_info: + await user_async.get_multi_with_total(session, skip=-1, limit=10) + + assert "skip must be non-negative" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_get_multi_with_total_validation_negative_limit(self, async_test_db): + """Test validation fails for negative limit.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + with pytest.raises(ValueError) as exc_info: + await user_async.get_multi_with_total(session, skip=0, limit=-1) + + assert "limit must be non-negative" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_get_multi_with_total_validation_max_limit(self, async_test_db): + """Test validation fails for limit > 1000.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + with pytest.raises(ValueError) as exc_info: + await user_async.get_multi_with_total(session, skip=0, limit=1001) + + assert "Maximum limit is 1000" in str(exc_info.value) + + +class TestBulkUpdateStatus: + """Tests for bulk_update_status method.""" + + @pytest.mark.asyncio + async def test_bulk_update_status_success(self, async_test_db): + """Test bulk updating user status.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create multiple users + user_ids = [] + async with AsyncTestingSessionLocal() as session: + for i in range(3): + user_data = UserCreate( + email=f"bulk{i}@example.com", + password="SecurePass123!", + first_name=f"Bulk{i}", + last_name="User" + ) + user = await user_async.create(session, obj_in=user_data) + user_ids.append(user.id) + + # Bulk deactivate + async with AsyncTestingSessionLocal() as session: + count = await user_async.bulk_update_status( + session, + user_ids=user_ids, + is_active=False + ) + assert count == 3 + + # Verify all are inactive + async with AsyncTestingSessionLocal() as session: + for user_id in user_ids: + user = await user_async.get(session, id=str(user_id)) + assert user.is_active is False + + @pytest.mark.asyncio + async def test_bulk_update_status_empty_list(self, async_test_db): + """Test bulk update with empty list returns 0.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + count = await user_async.bulk_update_status( + session, + user_ids=[], + is_active=False + ) + assert count == 0 + + @pytest.mark.asyncio + async def test_bulk_update_status_reactivate(self, async_test_db): + """Test bulk reactivating users.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create inactive user + async with AsyncTestingSessionLocal() as session: + user_data = UserCreate( + email="reactivate@example.com", + password="SecurePass123!", + first_name="Reactivate", + last_name="User" + ) + user = await user_async.create(session, obj_in=user_data) + # Deactivate + await user_async.update(session, db_obj=user, obj_in={"is_active": False}) + user_id = user.id + + # Reactivate + async with AsyncTestingSessionLocal() as session: + count = await user_async.bulk_update_status( + session, + user_ids=[user_id], + is_active=True + ) + assert count == 1 + + # Verify active + async with AsyncTestingSessionLocal() as session: + user = await user_async.get(session, id=str(user_id)) + assert user.is_active is True + + +class TestBulkSoftDelete: + """Tests for bulk_soft_delete method.""" + + @pytest.mark.asyncio + async def test_bulk_soft_delete_success(self, async_test_db): + """Test bulk soft deleting users.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create multiple users + user_ids = [] + async with AsyncTestingSessionLocal() as session: + for i in range(3): + user_data = UserCreate( + email=f"delete{i}@example.com", + password="SecurePass123!", + first_name=f"Delete{i}", + last_name="User" + ) + user = await user_async.create(session, obj_in=user_data) + user_ids.append(user.id) + + # Bulk delete + async with AsyncTestingSessionLocal() as session: + count = await user_async.bulk_soft_delete( + session, + user_ids=user_ids + ) + assert count == 3 + + # Verify all are soft deleted + async with AsyncTestingSessionLocal() as session: + for user_id in user_ids: + user = await user_async.get(session, id=str(user_id)) + assert user.deleted_at is not None + assert user.is_active is False + + @pytest.mark.asyncio + async def test_bulk_soft_delete_with_exclusion(self, async_test_db): + """Test bulk soft delete with excluded user.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create multiple users + user_ids = [] + async with AsyncTestingSessionLocal() as session: + for i in range(3): + user_data = UserCreate( + email=f"exclude{i}@example.com", + password="SecurePass123!", + first_name=f"Exclude{i}", + last_name="User" + ) + user = await user_async.create(session, obj_in=user_data) + user_ids.append(user.id) + + # Bulk delete, excluding first user + exclude_id = user_ids[0] + async with AsyncTestingSessionLocal() as session: + count = await user_async.bulk_soft_delete( + session, + user_ids=user_ids, + exclude_user_id=exclude_id + ) + assert count == 2 # Only 2 deleted + + # Verify excluded user is NOT deleted + async with AsyncTestingSessionLocal() as session: + excluded_user = await user_async.get(session, id=str(exclude_id)) + assert excluded_user.deleted_at is None + + @pytest.mark.asyncio + async def test_bulk_soft_delete_empty_list(self, async_test_db): + """Test bulk delete with empty list returns 0.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + count = await user_async.bulk_soft_delete( + session, + user_ids=[] + ) + assert count == 0 + + @pytest.mark.asyncio + async def test_bulk_soft_delete_all_excluded(self, async_test_db): + """Test bulk delete where all users are excluded.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create user + async with AsyncTestingSessionLocal() as session: + user_data = UserCreate( + email="onlyuser@example.com", + password="SecurePass123!", + first_name="Only", + last_name="User" + ) + user = await user_async.create(session, obj_in=user_data) + user_id = user.id + + # Try to delete but exclude + async with AsyncTestingSessionLocal() as session: + count = await user_async.bulk_soft_delete( + session, + user_ids=[user_id], + exclude_user_id=user_id + ) + assert count == 0 + + @pytest.mark.asyncio + async def test_bulk_soft_delete_already_deleted(self, async_test_db): + """Test bulk delete doesn't re-delete already deleted users.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create and delete user + async with AsyncTestingSessionLocal() as session: + user_data = UserCreate( + email="predeleted@example.com", + password="SecurePass123!", + first_name="PreDeleted", + last_name="User" + ) + user = await user_async.create(session, obj_in=user_data) + user_id = user.id + + # First deletion + await user_async.bulk_soft_delete(session, user_ids=[user_id]) + + # Try to delete again + async with AsyncTestingSessionLocal() as session: + count = await user_async.bulk_soft_delete( + session, + user_ids=[user_id] + ) + assert count == 0 # Already deleted + + +class TestUtilityMethods: + """Tests for utility methods.""" + + @pytest.mark.asyncio + async def test_is_active_true(self, async_test_db, async_test_user): + """Test is_active returns True for active user.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user = await user_async.get(session, id=str(async_test_user.id)) + assert user_async.is_active(user) is True + + @pytest.mark.asyncio + async def test_is_active_false(self, async_test_db): + """Test is_active returns False for inactive user.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user_data = UserCreate( + email="inactive2@example.com", + password="SecurePass123!", + first_name="Inactive", + last_name="User" + ) + user = await user_async.create(session, obj_in=user_data) + await user_async.update(session, db_obj=user, obj_in={"is_active": False}) + + assert user_async.is_active(user) is False + + @pytest.mark.asyncio + async def test_is_superuser_true(self, async_test_db, async_test_superuser): + """Test is_superuser returns True for superuser.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user = await user_async.get(session, id=str(async_test_superuser.id)) + assert user_async.is_superuser(user) is True + + @pytest.mark.asyncio + async def test_is_superuser_false(self, async_test_db, async_test_user): + """Test is_superuser returns False for regular user.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + user = await user_async.get(session, id=str(async_test_user.id)) + assert user_async.is_superuser(user) is False diff --git a/backend/tests/services/test_session_cleanup.py b/backend/tests/services/test_session_cleanup.py new file mode 100644 index 0000000..333f689 --- /dev/null +++ b/backend/tests/services/test_session_cleanup.py @@ -0,0 +1,334 @@ +# tests/services/test_session_cleanup.py +""" +Comprehensive tests for session cleanup service. +""" +import pytest +import asyncio +from datetime import datetime, timedelta, timezone +from unittest.mock import patch, MagicMock, AsyncMock +from contextlib import asynccontextmanager + +from app.models.user_session import UserSession +from sqlalchemy import select + + +class TestCleanupExpiredSessions: + """Tests for cleanup_expired_sessions function.""" + + @pytest.mark.asyncio + async def test_cleanup_expired_sessions_success(self, async_test_db, async_test_user): + """Test successful cleanup of expired sessions.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create mix of sessions + async with AsyncTestingSessionLocal() as session: + # 1. Active, not expired (should NOT be deleted) + active_session = UserSession( + user_id=async_test_user.id, + refresh_token_jti="active_jti_123", + device_name="Active Device", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + created_at=datetime.now(timezone.utc) - timedelta(days=1), + last_used_at=datetime.now(timezone.utc) + ) + + # 2. Inactive, expired, old (SHOULD be deleted) + old_expired_session = UserSession( + user_id=async_test_user.id, + refresh_token_jti="old_expired_jti", + device_name="Old Device", + ip_address="192.168.1.2", + user_agent="Mozilla/5.0", + is_active=False, + expires_at=datetime.now(timezone.utc) - timedelta(days=10), + created_at=datetime.now(timezone.utc) - timedelta(days=40), + last_used_at=datetime.now(timezone.utc) + ) + + # 3. Inactive, expired, recent (should NOT be deleted - within keep_days) + recent_expired_session = UserSession( + user_id=async_test_user.id, + refresh_token_jti="recent_expired_jti", + device_name="Recent Device", + ip_address="192.168.1.3", + user_agent="Mozilla/5.0", + is_active=False, + expires_at=datetime.now(timezone.utc) - timedelta(days=1), + created_at=datetime.now(timezone.utc) - timedelta(days=5), + last_used_at=datetime.now(timezone.utc) + ) + + session.add_all([active_session, old_expired_session, recent_expired_session]) + await session.commit() + + # Mock AsyncSessionLocal to return our test session + with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()): + from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=30) + + # Should only delete old_expired_session + assert deleted_count == 1 + + # Verify remaining sessions + async with AsyncTestingSessionLocal() as session: + result = await session.execute(select(UserSession)) + remaining = result.scalars().all() + assert len(remaining) == 2 + jtis = [s.refresh_token_jti for s in remaining] + assert "active_jti_123" in jtis + assert "recent_expired_jti" in jtis + assert "old_expired_jti" not in jtis + + @pytest.mark.asyncio + async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user): + """Test cleanup when no sessions meet deletion criteria.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + active = UserSession( + user_id=async_test_user.id, + refresh_token_jti="active_only_jti", + device_name="Active Device", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + created_at=datetime.now(timezone.utc), + last_used_at=datetime.now(timezone.utc) + ) + session.add(active) + await session.commit() + + with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()): + from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=30) + + assert deleted_count == 0 + + @pytest.mark.asyncio + async def test_cleanup_empty_database(self, async_test_db): + """Test cleanup with no sessions in database.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()): + from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=30) + + assert deleted_count == 0 + + @pytest.mark.asyncio + async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user): + """Test cleanup with keep_days=0 deletes all inactive expired sessions.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + today_expired = UserSession( + user_id=async_test_user.id, + refresh_token_jti="today_expired_jti", + device_name="Today Expired", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=False, + expires_at=datetime.now(timezone.utc) - timedelta(hours=1), + created_at=datetime.now(timezone.utc) - timedelta(hours=2), + last_used_at=datetime.now(timezone.utc) + ) + session.add(today_expired) + await session.commit() + + with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()): + from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=0) + + assert deleted_count == 1 + + @pytest.mark.asyncio + async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user): + """Test that cleanup uses bulk DELETE for many sessions.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create 50 expired sessions + async with AsyncTestingSessionLocal() as session: + sessions_to_add = [] + for i in range(50): + expired = UserSession( + user_id=async_test_user.id, + refresh_token_jti=f"bulk_jti_{i}", + device_name=f"Device {i}", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=False, + expires_at=datetime.now(timezone.utc) - timedelta(days=10), + created_at=datetime.now(timezone.utc) - timedelta(days=40), + last_used_at=datetime.now(timezone.utc) + ) + sessions_to_add.append(expired) + session.add_all(sessions_to_add) + await session.commit() + + with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()): + from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=30) + + assert deleted_count == 50 + + @pytest.mark.asyncio + async def test_cleanup_database_error_returns_zero(self, async_test_db): + """Test cleanup returns 0 on database errors (doesn't crash).""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Mock session_crud.cleanup_expired to raise error + with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()): + with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup: + mock_cleanup.side_effect = Exception("Database connection lost") + + from app.services.session_cleanup import cleanup_expired_sessions + # Should not crash, should return 0 + deleted_count = await cleanup_expired_sessions(keep_days=30) + + assert deleted_count == 0 + + +class TestGetSessionStatistics: + """Tests for get_session_statistics function.""" + + @pytest.mark.asyncio + async def test_get_statistics_with_sessions(self, async_test_db, async_test_user): + """Test getting session statistics with various session types.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + async with AsyncTestingSessionLocal() as session: + # 2 active, not expired + for i in range(2): + active = UserSession( + user_id=async_test_user.id, + refresh_token_jti=f"active_stat_{i}", + device_name=f"Active {i}", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + created_at=datetime.now(timezone.utc), + last_used_at=datetime.now(timezone.utc) + ) + session.add(active) + + # 3 inactive, expired + for i in range(3): + inactive = UserSession( + user_id=async_test_user.id, + refresh_token_jti=f"inactive_stat_{i}", + device_name=f"Inactive {i}", + ip_address="192.168.1.2", + user_agent="Mozilla/5.0", + is_active=False, + expires_at=datetime.now(timezone.utc) - timedelta(days=1), + created_at=datetime.now(timezone.utc) - timedelta(days=2), + last_used_at=datetime.now(timezone.utc) + ) + session.add(inactive) + + # 1 active but expired + expired_active = UserSession( + user_id=async_test_user.id, + refresh_token_jti="expired_active_stat", + device_name="Expired Active", + ip_address="192.168.1.3", + user_agent="Mozilla/5.0", + is_active=True, + expires_at=datetime.now(timezone.utc) - timedelta(hours=1), + created_at=datetime.now(timezone.utc) - timedelta(days=1), + last_used_at=datetime.now(timezone.utc) + ) + session.add(expired_active) + + await session.commit() + + with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()): + from app.services.session_cleanup import get_session_statistics + stats = await get_session_statistics() + + assert stats["total"] == 6 + assert stats["active"] == 3 # 2 active + 1 expired but active + assert stats["inactive"] == 3 + assert stats["expired"] == 4 # 3 inactive expired + 1 active expired + + @pytest.mark.asyncio + async def test_get_statistics_empty_database(self, async_test_db): + """Test getting statistics with no sessions.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()): + from app.services.session_cleanup import get_session_statistics + stats = await get_session_statistics() + + assert stats["total"] == 0 + assert stats["active"] == 0 + assert stats["inactive"] == 0 + assert stats["expired"] == 0 + + @pytest.mark.asyncio + async def test_get_statistics_database_error_returns_empty_dict(self, async_test_db): + """Test statistics returns empty dict on database errors.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create a mock that raises on execute + mock_session = AsyncMock() + mock_session.execute.side_effect = Exception("Database error") + + @asynccontextmanager + async def mock_session_local(): + yield mock_session + + with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=mock_session_local()): + from app.services.session_cleanup import get_session_statistics + stats = await get_session_statistics() + + assert stats == {} + + +class TestConcurrentCleanup: + """Tests for concurrent cleanup scenarios.""" + + @pytest.mark.asyncio + async def test_concurrent_cleanup_no_duplicate_deletes(self, async_test_db, async_test_user): + """Test concurrent cleanups don't cause race conditions.""" + test_engine, AsyncTestingSessionLocal = async_test_db + + # Create 10 expired sessions + async with AsyncTestingSessionLocal() as session: + for i in range(10): + expired = UserSession( + user_id=async_test_user.id, + refresh_token_jti=f"concurrent_jti_{i}", + device_name=f"Device {i}", + ip_address="192.168.1.1", + user_agent="Mozilla/5.0", + is_active=False, + expires_at=datetime.now(timezone.utc) - timedelta(days=10), + created_at=datetime.now(timezone.utc) - timedelta(days=40), + last_used_at=datetime.now(timezone.utc) + ) + session.add(expired) + await session.commit() + + # Run two cleanups concurrently + # Use side_effect to return fresh session instances for each call + with patch('app.services.session_cleanup.AsyncSessionLocal', side_effect=lambda: AsyncTestingSessionLocal()): + from app.services.session_cleanup import cleanup_expired_sessions + results = await asyncio.gather( + cleanup_expired_sessions(keep_days=30), + cleanup_expired_sessions(keep_days=30) + ) + + # Both should report deleting sessions (may overlap due to transaction timing) + assert sum(results) >= 10 + + # Verify all are deleted + async with AsyncTestingSessionLocal() as session: + result = await session.execute(select(UserSession)) + remaining = result.scalars().all() + assert len(remaining) == 0 diff --git a/backend/tests/utils/test_device.py b/backend/tests/utils/test_device.py new file mode 100644 index 0000000..f122441 --- /dev/null +++ b/backend/tests/utils/test_device.py @@ -0,0 +1,425 @@ +# tests/utils/test_device.py +""" +Comprehensive tests for device utility functions. +""" +import pytest +from unittest.mock import Mock + +from fastapi import Request + +from app.utils.device import ( + extract_device_info, + parse_device_name, + extract_browser, + get_client_ip, + is_mobile_device, + get_device_type +) + + +class TestParseDeviceName: + """Tests for parse_device_name function.""" + + def test_parse_device_name_empty_string(self): + """Test parsing empty user agent.""" + result = parse_device_name("") + assert result == "Unknown device" + + def test_parse_device_name_iphone(self): + """Test parsing iPhone user agent.""" + ua = "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)" + result = parse_device_name(ua) + assert result == "iPhone" + + def test_parse_device_name_ipad(self): + """Test parsing iPad user agent.""" + ua = "Mozilla/5.0 (iPad; CPU OS 14_0 like Mac OS X)" + result = parse_device_name(ua) + assert result == "iPad" + + def test_parse_device_name_android_with_model(self): + """Test parsing Android user agent with device model.""" + ua = "Mozilla/5.0 (Linux; Android 11; SM-G991B Build/RP1A)" + result = parse_device_name(ua) + assert result == "Android (Sm-G991B)" + + def test_parse_device_name_android_without_model(self): + """Test parsing Android user agent without model.""" + ua = "Mozilla/5.0 (Linux; Android)" + result = parse_device_name(ua) + assert result == "Android device" + + def test_parse_device_name_windows_phone(self): + """Test parsing Windows Phone user agent.""" + ua = "Mozilla/5.0 (Windows Phone 10.0)" + result = parse_device_name(ua) + assert result == "Windows Phone" + + def test_parse_device_name_mac(self): + """Test parsing Mac user agent.""" + ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36" + result = parse_device_name(ua) + assert result == "Chrome on Mac" + + def test_parse_device_name_windows(self): + """Test parsing Windows user agent.""" + ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36" + result = parse_device_name(ua) + assert result == "Chrome on Windows" + + def test_parse_device_name_linux(self): + """Test parsing Linux user agent.""" + ua = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36" + result = parse_device_name(ua) + assert result == "Chrome on Linux" + + def test_parse_device_name_chromebook(self): + """Test parsing Chromebook user agent.""" + ua = "Mozilla/5.0 (X11; CrOS x86_64 14092.0.0) AppleWebKit/537.36" + result = parse_device_name(ua) + assert result == "Chromebook" + + def test_parse_device_name_tablet(self): + """Test parsing generic tablet user agent.""" + ua = "Mozilla/5.0 (Linux; Android 9; Tablet) AppleWebKit/537.36" + result = parse_device_name(ua) + # Should match tablet first since it's in the string + assert "Tablet" in result or "Android" in result + + def test_parse_device_name_smart_tv(self): + """Test parsing Smart TV user agent.""" + ua = "Mozilla/5.0 (SMART-TV; Linux; Tizen 2.3)" + result = parse_device_name(ua) + assert result == "Smart TV" + + def test_parse_device_name_playstation(self): + """Test parsing PlayStation user agent.""" + ua = "Mozilla/5.0 (PlayStation 4 5.50)" + result = parse_device_name(ua) + assert result == "PlayStation" + + def test_parse_device_name_xbox(self): + """Test parsing Xbox user agent.""" + ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64; Xbox; Xbox One)" + result = parse_device_name(ua) + assert result == "Xbox" + + def test_parse_device_name_nintendo(self): + """Test parsing Nintendo user agent.""" + ua = "Mozilla/5.0 (Nintendo Switch)" + result = parse_device_name(ua) + assert result == "Nintendo" + + def test_parse_device_name_unknown(self): + """Test parsing completely unknown user agent.""" + ua = "SomeRandomBot/1.0" + result = parse_device_name(ua) + assert result == "Unknown device" + + +class TestExtractBrowser: + """Tests for extract_browser function.""" + + def test_extract_browser_empty_string(self): + """Test extracting browser from empty user agent.""" + result = extract_browser("") + assert result is None + + def test_extract_browser_none(self): + """Test extracting browser from None.""" + result = extract_browser(None) + assert result is None + + def test_extract_browser_edge(self): + """Test extracting Edge browser.""" + ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36 Edg/96.0.1054.62" + result = extract_browser(ua) + assert result == "Edge" + + def test_extract_browser_edge_legacy(self): + """Test extracting legacy Edge browser.""" + ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582" + result = extract_browser(ua) + assert result == "Edge" + + def test_extract_browser_opera(self): + """Test extracting Opera browser.""" + ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36 OPR/82.0.4227.50" + result = extract_browser(ua) + assert result == "Opera" + + def test_extract_browser_chrome(self): + """Test extracting Chrome browser.""" + ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36" + result = extract_browser(ua) + assert result == "Chrome" + + def test_extract_browser_safari(self): + """Test extracting Safari browser.""" + ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/15.0 Safari/605.1.15" + result = extract_browser(ua) + assert result == "Safari" + + def test_extract_browser_firefox(self): + """Test extracting Firefox browser.""" + ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:94.0) Gecko/20100101 Firefox/94.0" + result = extract_browser(ua) + assert result == "Firefox" + + def test_extract_browser_internet_explorer_msie(self): + """Test extracting Internet Explorer (MSIE).""" + ua = "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 10.0)" + result = extract_browser(ua) + assert result == "Internet Explorer" + + def test_extract_browser_internet_explorer_trident(self): + """Test extracting Internet Explorer (Trident).""" + ua = "Mozilla/5.0 (Windows NT 10.0; Trident/7.0; rv:11.0) like Gecko" + result = extract_browser(ua) + assert result == "Internet Explorer" + + def test_extract_browser_unknown(self): + """Test extracting from unknown browser.""" + ua = "SomeRandomBot/1.0" + result = extract_browser(ua) + assert result is None + + +class TestGetClientIp: + """Tests for get_client_ip function.""" + + def test_get_client_ip_x_forwarded_for_single(self): + """Test getting IP from X-Forwarded-For with single IP.""" + request = Mock(spec=Request) + request.headers = {"x-forwarded-for": "192.168.1.100"} + request.client = None + + result = get_client_ip(request) + assert result == "192.168.1.100" + + def test_get_client_ip_x_forwarded_for_multiple(self): + """Test getting IP from X-Forwarded-For with multiple IPs.""" + request = Mock(spec=Request) + request.headers = {"x-forwarded-for": "192.168.1.100, 10.0.0.1, 172.16.0.1"} + request.client = None + + result = get_client_ip(request) + assert result == "192.168.1.100" + + def test_get_client_ip_x_real_ip(self): + """Test getting IP from X-Real-IP.""" + request = Mock(spec=Request) + request.headers = {"x-real-ip": "192.168.1.200"} + request.client = None + + result = get_client_ip(request) + assert result == "192.168.1.200" + + def test_get_client_ip_direct_connection(self): + """Test getting IP from direct connection.""" + request = Mock(spec=Request) + request.headers = {} + request.client = Mock() + request.client.host = "192.168.1.50" + + result = get_client_ip(request) + assert result == "192.168.1.50" + + def test_get_client_ip_no_client(self): + """Test getting IP when no client info available.""" + request = Mock(spec=Request) + request.headers = {} + request.client = None + + result = get_client_ip(request) + assert result is None + + def test_get_client_ip_client_no_host(self): + """Test getting IP when client exists but no host.""" + request = Mock(spec=Request) + request.headers = {} + request.client = Mock() + request.client.host = None + + result = get_client_ip(request) + assert result is None + + def test_get_client_ip_priority_x_forwarded_for(self): + """Test that X-Forwarded-For has priority over X-Real-IP.""" + request = Mock(spec=Request) + request.headers = { + "x-forwarded-for": "192.168.1.100", + "x-real-ip": "192.168.1.200" + } + request.client = Mock() + request.client.host = "192.168.1.50" + + result = get_client_ip(request) + assert result == "192.168.1.100" + + def test_get_client_ip_priority_x_real_ip_over_client(self): + """Test that X-Real-IP has priority over client.host.""" + request = Mock(spec=Request) + request.headers = {"x-real-ip": "192.168.1.200"} + request.client = Mock() + request.client.host = "192.168.1.50" + + result = get_client_ip(request) + assert result == "192.168.1.200" + + +class TestIsMobileDevice: + """Tests for is_mobile_device function.""" + + def test_is_mobile_device_empty_string(self): + """Test with empty string.""" + result = is_mobile_device("") + assert result is False + + def test_is_mobile_device_iphone(self): + """Test iPhone user agent.""" + ua = "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)" + result = is_mobile_device(ua) + assert result is True + + def test_is_mobile_device_android(self): + """Test Android user agent.""" + ua = "Mozilla/5.0 (Linux; Android 11)" + result = is_mobile_device(ua) + assert result is True + + def test_is_mobile_device_ipad(self): + """Test iPad user agent.""" + ua = "Mozilla/5.0 (iPad; CPU OS 14_0 like Mac OS X)" + result = is_mobile_device(ua) + assert result is True + + def test_is_mobile_device_desktop(self): + """Test desktop user agent.""" + ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" + result = is_mobile_device(ua) + assert result is False + + def test_is_mobile_device_blackberry(self): + """Test BlackBerry user agent.""" + ua = "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900)" + result = is_mobile_device(ua) + assert result is True + + def test_is_mobile_device_windows_phone(self): + """Test Windows Phone user agent.""" + ua = "Mozilla/5.0 (Windows Phone 10.0)" + result = is_mobile_device(ua) + assert result is True + + +class TestGetDeviceType: + """Tests for get_device_type function.""" + + def test_get_device_type_empty_string(self): + """Test with empty string.""" + result = get_device_type("") + assert result == "other" + + def test_get_device_type_ipad(self): + """Test iPad returns tablet.""" + ua = "Mozilla/5.0 (iPad; CPU OS 14_0 like Mac OS X)" + result = get_device_type(ua) + assert result == "tablet" + + def test_get_device_type_tablet(self): + """Test generic tablet.""" + ua = "Mozilla/5.0 (Linux; Android 9; Tablet)" + result = get_device_type(ua) + assert result == "tablet" + + def test_get_device_type_iphone(self): + """Test iPhone returns mobile.""" + ua = "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)" + result = get_device_type(ua) + assert result == "mobile" + + def test_get_device_type_android_mobile(self): + """Test Android mobile.""" + ua = "Mozilla/5.0 (Linux; Android 11; SM-G991B) Mobile" + result = get_device_type(ua) + assert result == "mobile" + + def test_get_device_type_windows_desktop(self): + """Test Windows desktop.""" + ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64)" + result = get_device_type(ua) + assert result == "desktop" + + def test_get_device_type_mac_desktop(self): + """Test Mac desktop.""" + ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)" + result = get_device_type(ua) + assert result == "desktop" + + def test_get_device_type_linux_desktop(self): + """Test Linux desktop.""" + ua = "Mozilla/5.0 (X11; Linux x86_64)" + result = get_device_type(ua) + assert result == "desktop" + + def test_get_device_type_chromebook(self): + """Test Chromebook.""" + ua = "Mozilla/5.0 (X11; CrOS x86_64 14092.0.0)" + result = get_device_type(ua) + assert result == "desktop" + + def test_get_device_type_unknown(self): + """Test unknown device.""" + ua = "SomeRandomBot/1.0" + result = get_device_type(ua) + assert result == "other" + + +class TestExtractDeviceInfo: + """Tests for extract_device_info function.""" + + def test_extract_device_info_complete(self): + """Test extracting device info with all headers.""" + request = Mock(spec=Request) + request.headers = { + "user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)", + "x-device-id": "device-123-456", + "x-forwarded-for": "192.168.1.100" + } + request.client = None + + result = extract_device_info(request) + + assert result.device_name == "iPhone" + assert result.device_id == "device-123-456" + assert result.ip_address == "192.168.1.100" + assert "iPhone" in result.user_agent + assert result.location_city is None + assert result.location_country is None + + def test_extract_device_info_minimal(self): + """Test extracting device info with minimal headers.""" + request = Mock(spec=Request) + request.headers = {} + request.client = Mock() + request.client.host = "127.0.0.1" + + result = extract_device_info(request) + + assert result.device_name == "Unknown device" + assert result.device_id is None + assert result.ip_address == "127.0.0.1" + assert result.user_agent is None + + def test_extract_device_info_long_user_agent(self): + """Test that user agent is truncated to 500 chars.""" + long_ua = "A" * 600 + request = Mock(spec=Request) + request.headers = {"user-agent": long_ua} + request.client = None + + result = extract_device_info(request) + + assert len(result.user_agent) == 500 + assert result.user_agent == "A" * 500