Add comprehensive tests for session cleanup and async CRUD operations; improve error handling and validation across schemas and API routes
- Introduced extensive tests for session cleanup, async session CRUD methods, and concurrent cleanup to ensure reliability and efficiency. - Enhanced `schemas/users.py` with reusable password strength validation logic. - Improved error handling in `admin.py` routes by replacing `detail` with `message` for consistency and readability.
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
"""add_performance_indexes
|
||||
|
||||
Revision ID: 1174fffbe3e4
|
||||
Revises: fbf6318a8a36
|
||||
Create Date: 2025-11-01 04:15:25.367010
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1174fffbe3e4'
|
||||
down_revision: Union[str, None] = 'fbf6318a8a36'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add performance indexes for optimized queries."""
|
||||
|
||||
# Index for session cleanup queries
|
||||
# Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff
|
||||
op.create_index(
|
||||
'ix_user_sessions_cleanup',
|
||||
'user_sessions',
|
||||
['is_active', 'expires_at', 'created_at'],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('is_active = false')
|
||||
)
|
||||
|
||||
# Index for user search queries (basic trigram support without pg_trgm extension)
|
||||
# Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%'
|
||||
# Note: For better performance, consider enabling pg_trgm extension
|
||||
op.create_index(
|
||||
'ix_users_email_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(email)')],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_users_first_name_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(first_name)')],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_users_last_name_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(last_name)')],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
# Index for organization search
|
||||
op.create_index(
|
||||
'ix_organizations_name_lower',
|
||||
'organizations',
|
||||
[sa.text('LOWER(name)')],
|
||||
unique=False
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove performance indexes."""
|
||||
|
||||
# Drop indexes in reverse order
|
||||
op.drop_index('ix_organizations_name_lower', table_name='organizations')
|
||||
op.drop_index('ix_users_last_name_lower', table_name='users')
|
||||
op.drop_index('ix_users_first_name_lower', table_name='users')
|
||||
op.drop_index('ix_users_email_lower', table_name='users')
|
||||
op.drop_index('ix_user_sessions_cleanup', table_name='user_sessions')
|
||||
@@ -145,7 +145,7 @@ async def admin_create_user(
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create user: {str(e)}")
|
||||
raise NotFoundError(
|
||||
detail=str(e),
|
||||
message=str(e),
|
||||
error_code=ErrorCode.USER_ALREADY_EXISTS
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -169,7 +169,7 @@ async def admin_get_user(
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
detail=f"User {user_id} not found",
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
return user
|
||||
@@ -193,7 +193,7 @@ async def admin_update_user(
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
detail=f"User {user_id} not found",
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -225,7 +225,7 @@ async def admin_delete_user(
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
detail=f"User {user_id} not found",
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -269,7 +269,7 @@ async def admin_activate_user(
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
detail=f"User {user_id} not found",
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -305,7 +305,7 @@ async def admin_deactivate_user(
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
detail=f"User {user_id} not found",
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -491,7 +491,7 @@ async def admin_create_organization(
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create organization: {str(e)}")
|
||||
raise NotFoundError(
|
||||
detail=str(e),
|
||||
message=str(e),
|
||||
error_code=ErrorCode.ALREADY_EXISTS
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -515,7 +515,7 @@ async def admin_get_organization(
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {org_id} not found",
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -551,7 +551,7 @@ async def admin_update_organization(
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {org_id} not found",
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -595,7 +595,7 @@ async def admin_delete_organization(
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {org_id} not found",
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -633,7 +633,7 @@ async def admin_list_organization_members(
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {org_id} not found",
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -688,14 +688,14 @@ async def admin_add_organization_member(
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {org_id} not found",
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
user = await user_crud.get(db, id=request.user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
detail=f"User {request.user_id} not found",
|
||||
message=f"User {request.user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -749,14 +749,14 @@ async def admin_remove_organization_member(
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {org_id} not found",
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
detail=f"User {user_id} not found",
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
@@ -768,7 +768,7 @@ async def admin_remove_organization_member(
|
||||
|
||||
if not success:
|
||||
raise NotFoundError(
|
||||
detail="User is not a member of this organization",
|
||||
message="User is not a member of this organization",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ class UserUpdate(BaseModel):
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates
|
||||
|
||||
@@ -43,6 +44,14 @@ class UserUpdate(BaseModel):
|
||||
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
|
||||
return validate_phone_number(v)
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
def password_strength(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
if v is None:
|
||||
return v
|
||||
return validate_password_strength(v)
|
||||
|
||||
|
||||
class UserInDB(UserBase):
|
||||
id: UUID
|
||||
|
||||
@@ -68,6 +68,22 @@ def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
elif 'windows phone' in user_agent_lower:
|
||||
return "Windows Phone"
|
||||
|
||||
# Tablets (check before desktop, as some tablets contain "android")
|
||||
elif 'tablet' in user_agent_lower:
|
||||
return "Tablet"
|
||||
|
||||
# Smart TVs (check before desktop OS patterns)
|
||||
elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv']):
|
||||
return "Smart TV"
|
||||
|
||||
# Game consoles (check before desktop OS patterns, as Xbox contains "Windows")
|
||||
elif 'playstation' in user_agent_lower:
|
||||
return "PlayStation"
|
||||
elif 'xbox' in user_agent_lower:
|
||||
return "Xbox"
|
||||
elif 'nintendo' in user_agent_lower:
|
||||
return "Nintendo"
|
||||
|
||||
# Desktop operating systems
|
||||
elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower:
|
||||
# Try to extract browser
|
||||
@@ -82,22 +98,6 @@ def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
elif 'cros' in user_agent_lower:
|
||||
return "Chromebook"
|
||||
|
||||
# Tablets (not already caught)
|
||||
elif 'tablet' in user_agent_lower:
|
||||
return "Tablet"
|
||||
|
||||
# Smart TVs
|
||||
elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv', 'tv']):
|
||||
return "Smart TV"
|
||||
|
||||
# Game consoles
|
||||
elif 'playstation' in user_agent_lower:
|
||||
return "PlayStation"
|
||||
elif 'xbox' in user_agent_lower:
|
||||
return "Xbox"
|
||||
elif 'nintendo' in user_agent_lower:
|
||||
return "Nintendo"
|
||||
|
||||
# Fallback: just return browser name if detected
|
||||
browser = extract_browser(user_agent)
|
||||
if browser:
|
||||
|
||||
839
backend/tests/api/test_admin.py
Normal file
839
backend/tests/api/test_admin.py
Normal file
@@ -0,0 +1,839 @@
|
||||
# tests/api/test_admin.py
|
||||
"""
|
||||
Comprehensive tests for admin endpoints.
|
||||
"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from uuid import uuid4
|
||||
from fastapi import status
|
||||
|
||||
from app.models.organization import Organization
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def superuser_token(client, async_test_superuser):
|
||||
"""Get access token for superuser."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "superuser@example.com",
|
||||
"password": "SuperPassword123!"
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200, f"Login failed: {response.json()}"
|
||||
return response.json()["access_token"]
|
||||
|
||||
|
||||
# ===== USER MANAGEMENT TESTS =====
|
||||
|
||||
class TestAdminListUsers:
|
||||
"""Tests for GET /admin/users endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_list_users_success(self, client, superuser_token):
|
||||
"""Test successfully listing users as admin."""
|
||||
response = await client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "pagination" in data
|
||||
assert isinstance(data["data"], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_list_users_with_filters(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test listing users with filters."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
inactive_user = User(
|
||||
email="inactive@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="Inactive",
|
||||
last_name="User",
|
||||
is_active=False
|
||||
)
|
||||
session.add(inactive_user)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/admin/users?is_active=false",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data["data"]) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_list_users_with_search(self, client, async_test_superuser, superuser_token):
|
||||
"""Test searching users."""
|
||||
response = await client.get(
|
||||
"/api/v1/admin/users?search=superuser",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_list_users_unauthorized(self, client, async_test_user):
|
||||
"""Test non-admin cannot list users."""
|
||||
# Login as regular user
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"}
|
||||
)
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
class TestAdminCreateUser:
|
||||
"""Tests for POST /admin/users endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_create_user_success(self, client, async_test_superuser, superuser_token):
|
||||
"""Test successfully creating a user as admin."""
|
||||
response = await client.post(
|
||||
"/api/v1/admin/users",
|
||||
json={
|
||||
"email": "newadminuser@example.com",
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
assert data["email"] == "newadminuser@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_create_user_duplicate_email(self, client, async_test_superuser, async_test_user, superuser_token):
|
||||
"""Test creating user with duplicate email fails."""
|
||||
response = await client.post(
|
||||
"/api/v1/admin/users",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "Duplicate",
|
||||
"last_name": "User"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminGetUser:
|
||||
"""Tests for GET /admin/users/{user_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_get_user_success(self, client, async_test_superuser, async_test_user, superuser_token):
|
||||
"""Test successfully getting user details."""
|
||||
response = await client.get(
|
||||
f"/api/v1/admin/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["id"] == str(async_test_user.id)
|
||||
assert data["email"] == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_get_user_not_found(self, client, async_test_superuser, superuser_token):
|
||||
"""Test getting non-existent user."""
|
||||
response = await client.get(
|
||||
f"/api/v1/admin/users/{uuid4()}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminUpdateUser:
|
||||
"""Tests for PUT /admin/users/{user_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_update_user_success(self, client, async_test_superuser, async_test_user, superuser_token):
|
||||
"""Test successfully updating a user."""
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{async_test_user.id}",
|
||||
json={"first_name": "Updated"},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["first_name"] == "Updated"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_update_user_not_found(self, client, async_test_superuser, superuser_token):
|
||||
"""Test updating non-existent user."""
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{uuid4()}",
|
||||
json={"first_name": "Updated"},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminDeleteUser:
|
||||
"""Tests for DELETE /admin/users/{user_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_delete_user_success(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test successfully deleting a user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user to delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
user_to_delete = User(
|
||||
email="todelete@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="To",
|
||||
last_name="Delete"
|
||||
)
|
||||
session.add(user_to_delete)
|
||||
await session.commit()
|
||||
user_id = user_to_delete.id
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/admin/users/{user_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_delete_user_not_found(self, client, async_test_superuser, superuser_token):
|
||||
"""Test deleting non-existent user."""
|
||||
response = await client.delete(
|
||||
f"/api/v1/admin/users/{uuid4()}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_delete_self_forbidden(self, client, async_test_superuser, superuser_token):
|
||||
"""Test admin cannot delete their own account."""
|
||||
response = await client.delete(
|
||||
f"/api/v1/admin/users/{async_test_superuser.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
class TestAdminActivateUser:
|
||||
"""Tests for POST /admin/users/{user_id}/activate endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_activate_user_success(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test successfully activating a user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
inactive_user = User(
|
||||
email="toactivate@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="To",
|
||||
last_name="Activate",
|
||||
is_active=False
|
||||
)
|
||||
session.add(inactive_user)
|
||||
await session.commit()
|
||||
user_id = inactive_user.id
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/admin/users/{user_id}/activate",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_activate_user_not_found(self, client, async_test_superuser, superuser_token):
|
||||
"""Test activating non-existent user."""
|
||||
response = await client.post(
|
||||
f"/api/v1/admin/users/{uuid4()}/activate",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminDeactivateUser:
|
||||
"""Tests for POST /admin/users/{user_id}/deactivate endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_deactivate_user_success(self, client, async_test_superuser, async_test_user, superuser_token):
|
||||
"""Test successfully deactivating a user."""
|
||||
response = await client.post(
|
||||
f"/api/v1/admin/users/{async_test_user.id}/deactivate",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_deactivate_user_not_found(self, client, async_test_superuser, superuser_token):
|
||||
"""Test deactivating non-existent user."""
|
||||
response = await client.post(
|
||||
f"/api/v1/admin/users/{uuid4()}/deactivate",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_deactivate_self_forbidden(self, client, async_test_superuser, superuser_token):
|
||||
"""Test admin cannot deactivate their own account."""
|
||||
response = await client.post(
|
||||
f"/api/v1/admin/users/{async_test_superuser.id}/deactivate",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
class TestAdminBulkUserAction:
|
||||
"""Tests for POST /admin/users/bulk-action endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_bulk_activate_users(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test bulk activating users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
for i in range(3):
|
||||
user = User(
|
||||
email=f"bulk{i}@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name=f"Bulk{i}",
|
||||
last_name="User",
|
||||
is_active=False
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
user_ids.append(str(user.id))
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/admin/users/bulk-action",
|
||||
json={
|
||||
"action": "activate",
|
||||
"user_ids": user_ids
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["affected_count"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_bulk_deactivate_users(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test bulk deactivating users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
for i in range(2):
|
||||
user = User(
|
||||
email=f"deactivate{i}@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name=f"Deactivate{i}",
|
||||
last_name="User",
|
||||
is_active=True
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
user_ids.append(str(user.id))
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/admin/users/bulk-action",
|
||||
json={
|
||||
"action": "deactivate",
|
||||
"user_ids": user_ids
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["affected_count"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_bulk_delete_users(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test bulk deleting users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create users to delete
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
for i in range(2):
|
||||
user = User(
|
||||
email=f"bulkdelete{i}@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name=f"BulkDelete{i}",
|
||||
last_name="User"
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
user_ids.append(str(user.id))
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/admin/users/bulk-action",
|
||||
json={
|
||||
"action": "delete",
|
||||
"user_ids": user_ids
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["affected_count"] >= 0
|
||||
|
||||
|
||||
# ===== ORGANIZATION MANAGEMENT TESTS =====
|
||||
|
||||
class TestAdminListOrganizations:
|
||||
"""Tests for GET /admin/organizations endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_list_organizations_success(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test successfully listing organizations."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organization
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/admin/organizations",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "pagination" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_list_organizations_with_search(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test searching organizations."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organization
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Searchable Org", slug="searchable-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/admin/organizations?search=Searchable",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
class TestAdminCreateOrganization:
|
||||
"""Tests for POST /admin/organizations endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_create_organization_success(self, client, async_test_superuser, superuser_token):
|
||||
"""Test successfully creating an organization."""
|
||||
response = await client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
json={
|
||||
"name": "New Admin Org",
|
||||
"slug": "new-admin-org",
|
||||
"description": "Created by admin"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
assert data["name"] == "New Admin Org"
|
||||
assert data["member_count"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_create_organization_duplicate_slug(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test creating organization with duplicate slug fails."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create existing organization
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Existing", slug="duplicate-slug")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/admin/organizations",
|
||||
json={
|
||||
"name": "Duplicate",
|
||||
"slug": "duplicate-slug"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminGetOrganization:
|
||||
"""Tests for GET /admin/organizations/{org_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_get_organization_success(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test successfully getting organization details."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organization
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Get Test Org", slug="get-test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/admin/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["name"] == "Get Test Org"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_get_organization_not_found(self, client, async_test_superuser, superuser_token):
|
||||
"""Test getting non-existent organization."""
|
||||
response = await client.get(
|
||||
f"/api/v1/admin/organizations/{uuid4()}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminUpdateOrganization:
|
||||
"""Tests for PUT /admin/organizations/{org_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_update_organization_success(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test successfully updating an organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organization
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Update Test", slug="update-test")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/organizations/{org_id}",
|
||||
json={"name": "Updated Name"},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated Name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_update_organization_not_found(self, client, async_test_superuser, superuser_token):
|
||||
"""Test updating non-existent organization."""
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/organizations/{uuid4()}",
|
||||
json={"name": "Updated"},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminDeleteOrganization:
|
||||
"""Tests for DELETE /admin/organizations/{org_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_delete_organization_success(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test successfully deleting an organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organization
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Delete Test", slug="delete-test")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/admin/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_delete_organization_not_found(self, client, async_test_superuser, superuser_token):
|
||||
"""Test deleting non-existent organization."""
|
||||
response = await client.delete(
|
||||
f"/api/v1/admin/organizations/{uuid4()}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminListOrganizationMembers:
|
||||
"""Tests for GET /admin/organizations/{org_id}/members endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_list_organization_members_success(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token):
|
||||
"""Test successfully listing organization members."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organization with member
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Members Test", slug="members-test")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/admin/organizations/{org_id}/members",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert len(data["data"]) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_list_organization_members_not_found(self, client, async_test_superuser, superuser_token):
|
||||
"""Test listing members of non-existent organization."""
|
||||
response = await client.get(
|
||||
f"/api/v1/admin/organizations/{uuid4()}/members",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminAddOrganizationMember:
|
||||
"""Tests for POST /admin/organizations/{org_id}/members endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_add_organization_member_success(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token):
|
||||
"""Test successfully adding a member to organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organization
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Add Member Test", slug="add-member-test")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/admin/organizations/{org_id}/members",
|
||||
json={
|
||||
"user_id": str(async_test_user.id),
|
||||
"role": "member"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_add_organization_member_already_exists(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token):
|
||||
"""Test adding member who is already a member."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create organization with existing member
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Existing Member", slug="existing-member")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/admin/organizations/{org_id}/members",
|
||||
json={
|
||||
"user_id": str(async_test_user.id),
|
||||
"role": "member"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_add_organization_member_org_not_found(self, client, async_test_superuser, async_test_user, superuser_token):
|
||||
"""Test adding member to non-existent organization."""
|
||||
response = await client.post(
|
||||
f"/api/v1/admin/organizations/{uuid4()}/members",
|
||||
json={
|
||||
"user_id": str(async_test_user.id),
|
||||
"role": "member"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_add_organization_member_user_not_found(self, client, async_test_superuser, async_test_db, superuser_token):
|
||||
"""Test adding non-existent user to organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organization
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="User Not Found", slug="user-not-found")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/admin/organizations/{org_id}/members",
|
||||
json={
|
||||
"user_id": str(uuid4()),
|
||||
"role": "member"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAdminRemoveOrganizationMember:
|
||||
"""Tests for DELETE /admin/organizations/{org_id}/members/{user_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_remove_organization_member_success(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token):
|
||||
"""Test successfully removing a member from organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create organization with member
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Remove Member", slug="remove-member")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/admin/organizations/{org_id}/members/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_remove_organization_member_not_member(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token):
|
||||
"""Test removing user who is not a member."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create organization without member
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="No Member", slug="no-member")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/admin/organizations/{org_id}/members/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_remove_organization_member_org_not_found(self, client, async_test_superuser, async_test_user, superuser_token):
|
||||
"""Test removing member from non-existent organization."""
|
||||
response = await client.delete(
|
||||
f"/api/v1/admin/organizations/{uuid4()}/members/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -326,59 +326,3 @@ class TestRefreshTokenEndpoint:
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
class TestGetCurrentUserEndpoint:
|
||||
"""Tests for GET /auth/me endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_success(self, client, async_test_user):
|
||||
"""Test getting current user info."""
|
||||
# First, login to get an access token
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
)
|
||||
access_token = login_response.json()["access_token"]
|
||||
|
||||
# Get current user info
|
||||
response = await client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["email"] == async_test_user.email
|
||||
assert data["first_name"] == async_test_user.first_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_no_token(self, client):
|
||||
"""Test getting current user without token."""
|
||||
response = await client.get("/api/v1/auth/me")
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_invalid_token(self, client):
|
||||
"""Test getting current user with invalid token."""
|
||||
response = await client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": "Bearer invalid_token"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_expired_token(self, client):
|
||||
"""Test getting current user with expired token."""
|
||||
# Use a clearly invalid/malformed token
|
||||
response = await client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.invalid"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
944
backend/tests/crud/test_organization_async.py
Normal file
944
backend/tests/crud/test_organization_async.py
Normal file
@@ -0,0 +1,944 @@
|
||||
# tests/crud/test_organization_async.py
|
||||
"""
|
||||
Comprehensive tests for async organization CRUD operations.
|
||||
"""
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.organization_async import organization_async
|
||||
from app.models.organization import Organization
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.models.user import User
|
||||
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
|
||||
|
||||
|
||||
class TestGetBySlug:
|
||||
"""Tests for get_by_slug method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_slug_success(self, async_test_db):
|
||||
"""Test successfully getting an organization by slug."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organization
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Test Org",
|
||||
slug="test-org",
|
||||
description="Test description"
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Get by slug
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.get_by_slug(session, slug="test-org")
|
||||
assert result is not None
|
||||
assert result.id == org_id
|
||||
assert result.slug == "test-org"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_slug_not_found(self, async_test_db):
|
||||
"""Test getting non-existent organization by slug."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.get_by_slug(session, slug="nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreate:
|
||||
"""Tests for create method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_success(self, async_test_db):
|
||||
"""Test successfully creating an organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org_in = OrganizationCreate(
|
||||
name="New Org",
|
||||
slug="new-org",
|
||||
description="New organization",
|
||||
is_active=True,
|
||||
settings={"key": "value"}
|
||||
)
|
||||
result = await organization_async.create(session, obj_in=org_in)
|
||||
|
||||
assert result.name == "New Org"
|
||||
assert result.slug == "new-org"
|
||||
assert result.description == "New organization"
|
||||
assert result.is_active is True
|
||||
assert result.settings == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_slug(self, async_test_db):
|
||||
"""Test creating organization with duplicate slug raises error."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create first org
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org1 = Organization(name="Org 1", slug="duplicate-slug")
|
||||
session.add(org1)
|
||||
await session.commit()
|
||||
|
||||
# Try to create second with same slug
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org_in = OrganizationCreate(
|
||||
name="Org 2",
|
||||
slug="duplicate-slug"
|
||||
)
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
await organization_async.create(session, obj_in=org_in)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_without_settings(self, async_test_db):
|
||||
"""Test creating organization without settings (defaults to empty dict)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org_in = OrganizationCreate(
|
||||
name="No Settings Org",
|
||||
slug="no-settings"
|
||||
)
|
||||
result = await organization_async.create(session, obj_in=org_in)
|
||||
|
||||
assert result.settings == {}
|
||||
|
||||
|
||||
class TestGetMultiWithFilters:
|
||||
"""Tests for get_multi_with_filters method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_no_filters(self, async_test_db):
|
||||
"""Test getting organizations without any filters."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create test organizations
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(5):
|
||||
org = Organization(name=f"Org {i}", slug=f"org-{i}")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(session)
|
||||
assert total == 5
|
||||
assert len(orgs) == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_is_active(self, async_test_db):
|
||||
"""Test filtering by is_active."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org1 = Organization(name="Active", slug="active", is_active=True)
|
||||
org2 = Organization(name="Inactive", slug="inactive", is_active=False)
|
||||
session.add_all([org1, org2])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(
|
||||
session,
|
||||
is_active=True
|
||||
)
|
||||
assert total == 1
|
||||
assert orgs[0].name == "Active"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_search(self, async_test_db):
|
||||
"""Test searching organizations."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org1 = Organization(name="Tech Corp", slug="tech-corp", description="Technology")
|
||||
org2 = Organization(name="Food Inc", slug="food-inc", description="Restaurant")
|
||||
session.add_all([org1, org2])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(
|
||||
session,
|
||||
search="tech"
|
||||
)
|
||||
assert total == 1
|
||||
assert orgs[0].name == "Tech Corp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_pagination(self, async_test_db):
|
||||
"""Test pagination."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(10):
|
||||
org = Organization(name=f"Org {i}", slug=f"org-{i}")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(
|
||||
session,
|
||||
skip=2,
|
||||
limit=3
|
||||
)
|
||||
assert total == 10
|
||||
assert len(orgs) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_sorting(self, async_test_db):
|
||||
"""Test sorting."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org1 = Organization(name="B Org", slug="b-org")
|
||||
org2 = Organization(name="A Org", slug="a-org")
|
||||
session.add_all([org1, org2])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs, total = await organization_async.get_multi_with_filters(
|
||||
session,
|
||||
sort_by="name",
|
||||
sort_order="asc"
|
||||
)
|
||||
assert orgs[0].name == "A Org"
|
||||
assert orgs[1].name == "B Org"
|
||||
|
||||
|
||||
class TestGetMemberCount:
|
||||
"""Tests for get_member_count method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_member_count_success(self, async_test_db, async_test_user):
|
||||
"""Test getting member count for organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
# Add 1 active member
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await organization_async.get_member_count(session, organization_id=org_id)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_member_count_no_members(self, async_test_db):
|
||||
"""Test getting member count for organization with no members."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Empty Org", slug="empty-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await organization_async.get_member_count(session, organization_id=org_id)
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestAddUser:
|
||||
"""Tests for add_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully adding a user to organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.add_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id,
|
||||
role=OrganizationRole.ADMIN
|
||||
)
|
||||
|
||||
assert result.user_id == async_test_user.id
|
||||
assert result.organization_id == org_id
|
||||
assert result.role == OrganizationRole.ADMIN
|
||||
assert result.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_already_active_member(self, async_test_db, async_test_user):
|
||||
"""Test adding user who is already an active member raises error."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="already a member"):
|
||||
await organization_async.add_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_reactivate_inactive(self, async_test_db, async_test_user):
|
||||
"""Test adding user who was previously inactive reactivates them."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=False
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.add_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id,
|
||||
role=OrganizationRole.ADMIN
|
||||
)
|
||||
|
||||
assert result.is_active is True
|
||||
assert result.role == OrganizationRole.ADMIN
|
||||
|
||||
|
||||
class TestRemoveUser:
|
||||
"""Tests for remove_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully removing a user from organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.remove_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify soft delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
stmt = select(UserOrganization).where(
|
||||
UserOrganization.user_id == async_test_user.id,
|
||||
UserOrganization.organization_id == org_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
user_org = result.scalar_one_or_none()
|
||||
assert user_org.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_not_found(self, async_test_db):
|
||||
"""Test removing non-existent user returns False."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.remove_user(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=uuid4()
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestUpdateUserRole:
|
||||
"""Tests for update_user_role method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_role_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully updating user role."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.update_user_role(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=async_test_user.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
custom_permissions="custom"
|
||||
)
|
||||
|
||||
assert result.role == OrganizationRole.ADMIN
|
||||
assert result.custom_permissions == "custom"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_role_not_found(self, async_test_db):
|
||||
"""Test updating role for non-existent user returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await organization_async.update_user_role(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
user_id=uuid4(),
|
||||
role=OrganizationRole.ADMIN
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetOrganizationMembers:
|
||||
"""Tests for get_organization_members method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_members_success(self, async_test_db, async_test_user):
|
||||
"""Test getting organization members."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
members, total = await organization_async.get_organization_members(
|
||||
session,
|
||||
organization_id=org_id
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert len(members) == 1
|
||||
assert members[0]["user_id"] == async_test_user.id
|
||||
assert members[0]["email"] == async_test_user.email
|
||||
assert members[0]["role"] == OrganizationRole.ADMIN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_members_with_pagination(self, async_test_db, async_test_user):
|
||||
"""Test getting organization members with pagination."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
members, total = await organization_async.get_organization_members(
|
||||
session,
|
||||
organization_id=org_id,
|
||||
skip=0,
|
||||
limit=10
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert len(members) <= 10
|
||||
|
||||
|
||||
class TestGetUserOrganizations:
|
||||
"""Tests for get_user_organizations method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_organizations_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user's organizations."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs = await organization_async.get_user_organizations(
|
||||
session,
|
||||
user_id=async_test_user.id
|
||||
)
|
||||
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == "Test Org"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_organizations_filter_inactive(self, async_test_db, async_test_user):
|
||||
"""Test filtering inactive organizations."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org1 = Organization(name="Active Org", slug="active-org")
|
||||
org2 = Organization(name="Inactive Org", slug="inactive-org")
|
||||
session.add_all([org1, org2])
|
||||
await session.commit()
|
||||
|
||||
user_org1 = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org1.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
user_org2 = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org2.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=False
|
||||
)
|
||||
session.add_all([user_org1, user_org2])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs = await organization_async.get_user_organizations(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == "Active Org"
|
||||
|
||||
|
||||
class TestGetUserRole:
|
||||
"""Tests for get_user_role_in_org method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_role_in_org_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user role in organization."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
role = await organization_async.get_user_role_in_org(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
)
|
||||
|
||||
assert role == OrganizationRole.ADMIN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_role_in_org_not_found(self, async_test_db):
|
||||
"""Test getting role for non-member returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
role = await organization_async.get_user_role_in_org(
|
||||
session,
|
||||
user_id=uuid4(),
|
||||
organization_id=org_id
|
||||
)
|
||||
|
||||
assert role is None
|
||||
|
||||
|
||||
class TestIsUserOrgOwner:
|
||||
"""Tests for is_user_org_owner method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_user_org_owner_true(self, async_test_db, async_test_user):
|
||||
"""Test checking if user is owner."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.OWNER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_owner = await organization_async.is_user_org_owner(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
)
|
||||
|
||||
assert is_owner is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_user_org_owner_false(self, async_test_db, async_test_user):
|
||||
"""Test checking if non-owner user is owner."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_owner = await organization_async.is_user_org_owner(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
)
|
||||
|
||||
assert is_owner is False
|
||||
|
||||
|
||||
class TestGetMultiWithMemberCounts:
|
||||
"""Tests for get_multi_with_member_counts method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_member_counts_success(self, async_test_db, async_test_user):
|
||||
"""Test getting organizations with member counts."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org1 = Organization(name="Org 1", slug="org-1")
|
||||
org2 = Organization(name="Org 2", slug="org-2")
|
||||
session.add_all([org1, org2])
|
||||
await session.commit()
|
||||
|
||||
# Add members to org1
|
||||
user_org1 = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org1.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org1)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_counts, total = await organization_async.get_multi_with_member_counts(session)
|
||||
|
||||
assert total == 2
|
||||
assert len(orgs_with_counts) == 2
|
||||
# Verify structure
|
||||
assert 'organization' in orgs_with_counts[0]
|
||||
assert 'member_count' in orgs_with_counts[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_member_counts_with_filters(self, async_test_db):
|
||||
"""Test getting organizations with member counts and filters."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org1 = Organization(name="Active Org", slug="active-org", is_active=True)
|
||||
org2 = Organization(name="Inactive Org", slug="inactive-org", is_active=False)
|
||||
session.add_all([org1, org2])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_counts, total = await organization_async.get_multi_with_member_counts(
|
||||
session,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert orgs_with_counts[0]['organization'].name == "Active Org"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_member_counts_with_search(self, async_test_db):
|
||||
"""Test searching organizations with member counts."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org1 = Organization(name="Tech Corp", slug="tech-corp")
|
||||
org2 = Organization(name="Food Inc", slug="food-inc")
|
||||
session.add_all([org1, org2])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_counts, total = await organization_async.get_multi_with_member_counts(
|
||||
session,
|
||||
search="tech"
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert orgs_with_counts[0]['organization'].name == "Tech Corp"
|
||||
|
||||
|
||||
class TestGetUserOrganizationsWithDetails:
|
||||
"""Tests for get_user_organizations_with_details method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_organizations_with_details_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user organizations with role and member count."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_details = await organization_async.get_user_organizations_with_details(
|
||||
session,
|
||||
user_id=async_test_user.id
|
||||
)
|
||||
|
||||
assert len(orgs_with_details) == 1
|
||||
assert orgs_with_details[0]['organization'].name == "Test Org"
|
||||
assert orgs_with_details[0]['role'] == OrganizationRole.ADMIN
|
||||
assert 'member_count' in orgs_with_details[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_organizations_with_details_filter_inactive(self, async_test_db, async_test_user):
|
||||
"""Test filtering inactive organizations in user details."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org1 = Organization(name="Active Org", slug="active-org")
|
||||
org2 = Organization(name="Inactive Org", slug="inactive-org")
|
||||
session.add_all([org1, org2])
|
||||
await session.commit()
|
||||
|
||||
user_org1 = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org1.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
user_org2 = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org2.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=False
|
||||
)
|
||||
session.add_all([user_org1, user_org2])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
orgs_with_details = await organization_async.get_user_organizations_with_details(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
assert len(orgs_with_details) == 1
|
||||
assert orgs_with_details[0]['organization'].name == "Active Org"
|
||||
|
||||
|
||||
class TestIsUserOrgAdmin:
|
||||
"""Tests for is_user_org_admin method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_user_org_admin_owner(self, async_test_db, async_test_user):
|
||||
"""Test checking if owner is admin (should be True)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.OWNER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_admin = await organization_async.is_user_org_admin(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
)
|
||||
|
||||
assert is_admin is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_user_org_admin_admin_role(self, async_test_db, async_test_user):
|
||||
"""Test checking if admin role is admin."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_admin = await organization_async.is_user_org_admin(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
)
|
||||
|
||||
assert is_admin is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_user_org_admin_member_false(self, async_test_db, async_test_user):
|
||||
"""Test checking if regular member is admin."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
)
|
||||
session.add(user_org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
is_admin = await organization_async.is_user_org_admin(
|
||||
session,
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org_id
|
||||
)
|
||||
|
||||
assert is_admin is False
|
||||
339
backend/tests/crud/test_session_async.py
Normal file
339
backend/tests/crud/test_session_async.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# tests/crud/test_session_async.py
|
||||
"""
|
||||
Comprehensive tests for async session CRUD operations.
|
||||
"""
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud.session_async import session_async
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
|
||||
class TestGetByJti:
|
||||
"""Tests for get_by_jti method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_success(self, async_test_db, async_test_user):
|
||||
"""Test getting session by JTI."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="test_jti_123",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.get_by_jti(session, jti="test_jti_123")
|
||||
assert result is not None
|
||||
assert result.refresh_token_jti == "test_jti_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_not_found(self, async_test_db):
|
||||
"""Test getting non-existent JTI returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.get_by_jti(session, jti="nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetActiveByJti:
|
||||
"""Tests for get_active_by_jti method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_success(self, async_test_db, async_test_user):
|
||||
"""Test getting active session by JTI."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active_jti",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.get_active_by_jti(session, jti="active_jti")
|
||||
assert result is not None
|
||||
assert result.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user):
|
||||
"""Test getting inactive session by JTI returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="inactive_jti",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.get_active_by_jti(session, jti="inactive_jti")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetUserSessions:
|
||||
"""Tests for get_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
|
||||
"""Test getting only active user sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active",
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
inactive = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="inactive",
|
||||
device_name="Inactive Device",
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add_all([active, inactive])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_async.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
active_only=True
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0].is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
|
||||
"""Test getting all user sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
sess = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"session_{i}",
|
||||
device_name=f"Device {i}",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=i % 2 == 0,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_async.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
active_only=False
|
||||
)
|
||||
assert len(results) == 3
|
||||
|
||||
|
||||
class TestCreateSession:
|
||||
"""Tests for create_session method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully creating a session."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="new_jti",
|
||||
device_name="New Device",
|
||||
device_id="device_123",
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0",
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
location_city="San Francisco",
|
||||
location_country="USA"
|
||||
)
|
||||
result = await session_async.create_session(session, obj_in=session_data)
|
||||
|
||||
assert result.user_id == async_test_user.id
|
||||
assert result.refresh_token_jti == "new_jti"
|
||||
assert result.is_active is True
|
||||
assert result.location_city == "San Francisco"
|
||||
|
||||
|
||||
class TestDeactivate:
|
||||
"""Tests for deactivate method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully deactivating a session."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="to_deactivate",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
session_id = user_session.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.deactivate(session, session_id=str(session_id))
|
||||
assert result is not None
|
||||
assert result.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_not_found(self, async_test_db):
|
||||
"""Test deactivating non-existent session returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_async.deactivate(session, session_id=str(uuid4()))
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestDeactivateAllUserSessions:
|
||||
"""Tests for deactivate_all_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_user_sessions_success(self, async_test_db, async_test_user):
|
||||
"""Test deactivating all user sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(5):
|
||||
sess = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"bulk_{i}",
|
||||
device_name=f"Device {i}",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_async.deactivate_all_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 5
|
||||
|
||||
|
||||
class TestUpdateLastUsed:
|
||||
"""Tests for update_last_used method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_last_used_success(self, async_test_db, async_test_user):
|
||||
"""Test updating last_used_at timestamp."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="update_test",
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
old_time = user_session.last_used_at
|
||||
result = await session_async.update_last_used(session, session=user_session)
|
||||
|
||||
assert result.last_used_at > old_time
|
||||
|
||||
|
||||
class TestGetUserSessionCount:
|
||||
"""Tests for get_user_session_count method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user session count."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
sess = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"count_{i}",
|
||||
device_name=f"Device {i}",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_async.get_user_session_count(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_empty(self, async_test_db):
|
||||
"""Test getting session count for user with no sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_async.get_user_session_count(
|
||||
session,
|
||||
user_id=str(uuid4())
|
||||
)
|
||||
assert count == 0
|
||||
644
backend/tests/crud/test_user_async.py
Normal file
644
backend/tests/crud/test_user_async.py
Normal file
@@ -0,0 +1,644 @@
|
||||
# tests/crud/test_user_async.py
|
||||
"""
|
||||
Comprehensive tests for async user CRUD operations.
|
||||
"""
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud.user_async import user_async
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class TestGetByEmail:
|
||||
"""Tests for get_by_email method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user by email."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_async.get_by_email(session, email=async_test_user.email)
|
||||
assert result is not None
|
||||
assert result.email == async_test_user.email
|
||||
assert result.id == async_test_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_not_found(self, async_test_db):
|
||||
"""Test getting non-existent email returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_async.get_by_email(session, email="nonexistent@example.com")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreate:
|
||||
"""Tests for create method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, async_test_db):
|
||||
"""Test successfully creating a user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="newuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="+1234567890"
|
||||
)
|
||||
result = await user_async.create(session, obj_in=user_data)
|
||||
|
||||
assert result.email == "newuser@example.com"
|
||||
assert result.first_name == "New"
|
||||
assert result.last_name == "User"
|
||||
assert result.phone_number == "+1234567890"
|
||||
assert result.is_active is True
|
||||
assert result.is_superuser is False
|
||||
assert result.password_hash is not None
|
||||
assert result.password_hash != "SecurePass123!" # Password should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_superuser_success(self, async_test_db):
|
||||
"""Test creating a superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="superuser@example.com",
|
||||
password="SuperPass123!",
|
||||
first_name="Super",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
)
|
||||
result = await user_async.create(session, obj_in=user_data)
|
||||
|
||||
assert result.is_superuser is True
|
||||
assert result.email == "superuser@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
|
||||
"""Test creating user with duplicate email raises ValueError."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Duplicate email
|
||||
password="AnotherPass123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
assert "already exists" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
"""Tests for update method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
|
||||
"""Test updating basic user fields."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get fresh copy of user
|
||||
user = await user_async.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_data = UserUpdate(
|
||||
first_name="Updated",
|
||||
last_name="Name",
|
||||
phone_number="+9876543210"
|
||||
)
|
||||
result = await user_async.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
assert result.first_name == "Updated"
|
||||
assert result.last_name == "Name"
|
||||
assert result.phone_number == "+9876543210"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_password(self, async_test_db):
|
||||
"""Test updating user password."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a fresh user for this test
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="passwordtest@example.com",
|
||||
password="OldPassword123!",
|
||||
first_name="Pass",
|
||||
last_name="Test"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
old_password_hash = user.password_hash
|
||||
|
||||
# Update the password
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(user_id))
|
||||
|
||||
update_data = UserUpdate(password="NewDifferentPassword123!")
|
||||
result = await user_async.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
await session.refresh(result)
|
||||
assert result.password_hash != old_password_hash
|
||||
assert result.password_hash is not None
|
||||
assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test updating user with dictionary."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_dict = {"first_name": "DictUpdate"}
|
||||
result = await user_async.update(session, db_obj=user, obj_in=update_dict)
|
||||
|
||||
assert result.first_name == "DictUpdate"
|
||||
|
||||
|
||||
class TestGetMultiWithTotal:
|
||||
"""Tests for get_multi_with_total method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
|
||||
"""Test basic pagination."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10
|
||||
)
|
||||
assert total >= 1
|
||||
assert len(users) >= 1
|
||||
assert any(u.id == async_test_user.id for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
|
||||
"""Test sorting in ascending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"sort{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
)
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="asc"
|
||||
)
|
||||
|
||||
# Check if sorted (at least the test users)
|
||||
test_users = [u for u in users if u.email.startswith("sort")]
|
||||
if len(test_users) > 1:
|
||||
assert test_users[0].email < test_users[1].email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
|
||||
"""Test sorting in descending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"desc{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
)
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="desc"
|
||||
)
|
||||
|
||||
# Check if sorted descending (at least the test users)
|
||||
test_users = [u for u in users if u.email.startswith("desc")]
|
||||
if len(test_users) > 1:
|
||||
assert test_users[0].email > test_users[1].email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_filtering(self, async_test_db):
|
||||
"""Test filtering by field."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active and inactive users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_user = UserCreate(
|
||||
email="active@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Active",
|
||||
last_name="User"
|
||||
)
|
||||
await user_async.create(session, obj_in=active_user)
|
||||
|
||||
inactive_user = UserCreate(
|
||||
email="inactive@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
)
|
||||
created_inactive = await user_async.create(session, obj_in=inactive_user)
|
||||
|
||||
# Deactivate the user
|
||||
await user_async.update(
|
||||
session,
|
||||
db_obj=created_inactive,
|
||||
obj_in={"is_active": False}
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
filters={"is_active": True}
|
||||
)
|
||||
|
||||
# All returned users should be active
|
||||
assert all(u.is_active for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_search(self, async_test_db):
|
||||
"""Test search functionality."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user with unique name
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="searchable@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Searchable",
|
||||
last_name="UserName"
|
||||
)
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
search="Searchable"
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert any(u.first_name == "Searchable" for u in users)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_pagination(self, async_test_db):
|
||||
"""Test pagination with skip and limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(5):
|
||||
user_data = UserCreate(
|
||||
email=f"page{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Page{i}",
|
||||
last_name="User"
|
||||
)
|
||||
await user_async.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get first page
|
||||
users_page1, total = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Get second page
|
||||
users_page2, total2 = await user_async.get_multi_with_total(
|
||||
session,
|
||||
skip=2,
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Total should be same
|
||||
assert total == total2
|
||||
# Different users on different pages
|
||||
assert users_page1[0].id != users_page2[0].id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
|
||||
"""Test validation fails for negative skip."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_async.get_multi_with_total(session, skip=-1, limit=10)
|
||||
|
||||
assert "skip must be non-negative" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
|
||||
"""Test validation fails for negative limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_async.get_multi_with_total(session, skip=0, limit=-1)
|
||||
|
||||
assert "limit must be non-negative" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
|
||||
"""Test validation fails for limit > 1000."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await user_async.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
assert "Maximum limit is 1000" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestBulkUpdateStatus:
|
||||
"""Tests for bulk_update_status method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_success(self, async_test_db):
|
||||
"""Test bulk updating user status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"bulk{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Bulk{i}",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk deactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_update_status(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
is_active=False
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are inactive
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for user_id in user_ids:
|
||||
user = await user_async.get(session, id=str(user_id))
|
||||
assert user.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_empty_list(self, async_test_db):
|
||||
"""Test bulk update with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_update_status(
|
||||
session,
|
||||
user_ids=[],
|
||||
is_active=False
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_reactivate(self, async_test_db):
|
||||
"""Test bulk reactivating users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="reactivate@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Reactivate",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
# Deactivate
|
||||
await user_async.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
user_id = user.id
|
||||
|
||||
# Reactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_update_status(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
is_active=True
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
# Verify active
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(user_id))
|
||||
assert user.is_active is True
|
||||
|
||||
|
||||
class TestBulkSoftDelete:
|
||||
"""Tests for bulk_soft_delete method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_success(self, async_test_db):
|
||||
"""Test bulk soft deleting users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"delete{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Delete{i}",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are soft deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for user_id in user_ids:
|
||||
user = await user_async.get(session, id=str(user_id))
|
||||
assert user.deleted_at is not None
|
||||
assert user.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
|
||||
"""Test bulk soft delete with excluded user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
user_data = UserCreate(
|
||||
email=f"exclude{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Exclude{i}",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete, excluding first user
|
||||
exclude_id = user_ids[0]
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
exclude_user_id=exclude_id
|
||||
)
|
||||
assert count == 2 # Only 2 deleted
|
||||
|
||||
# Verify excluded user is NOT deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
excluded_user = await user_async.get(session, id=str(exclude_id))
|
||||
assert excluded_user.deleted_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_empty_list(self, async_test_db):
|
||||
"""Test bulk delete with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[]
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
|
||||
"""Test bulk delete where all users are excluded."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="onlyuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Only",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
|
||||
# Try to delete but exclude
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
exclude_user_id=user_id
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
|
||||
"""Test bulk delete doesn't re-delete already deleted users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create and delete user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="predeleted@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="PreDeleted",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
|
||||
# First deletion
|
||||
await user_async.bulk_soft_delete(session, user_ids=[user_id])
|
||||
|
||||
# Try to delete again
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_async.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id]
|
||||
)
|
||||
assert count == 0 # Already deleted
|
||||
|
||||
|
||||
class TestUtilityMethods:
|
||||
"""Tests for utility methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_true(self, async_test_db, async_test_user):
|
||||
"""Test is_active returns True for active user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(async_test_user.id))
|
||||
assert user_async.is_active(user) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_false(self, async_test_db):
|
||||
"""Test is_active returns False for inactive user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="inactive2@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
)
|
||||
user = await user_async.create(session, obj_in=user_data)
|
||||
await user_async.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
|
||||
assert user_async.is_active(user) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
|
||||
"""Test is_superuser returns True for superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(async_test_superuser.id))
|
||||
assert user_async.is_superuser(user) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_false(self, async_test_db, async_test_user):
|
||||
"""Test is_superuser returns False for regular user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_async.get(session, id=str(async_test_user.id))
|
||||
assert user_async.is_superuser(user) is False
|
||||
334
backend/tests/services/test_session_cleanup.py
Normal file
334
backend/tests/services/test_session_cleanup.py
Normal file
@@ -0,0 +1,334 @@
|
||||
# tests/services/test_session_cleanup.py
|
||||
"""
|
||||
Comprehensive tests for session cleanup service.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
class TestCleanupExpiredSessions:
|
||||
"""Tests for cleanup_expired_sessions function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_success(self, async_test_db, async_test_user):
|
||||
"""Test successful cleanup of expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create mix of sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# 1. Active, not expired (should NOT be deleted)
|
||||
active_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active_jti_123",
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# 2. Inactive, expired, old (SHOULD be deleted)
|
||||
old_expired_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="old_expired_jti",
|
||||
device_name="Old Device",
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# 3. Inactive, expired, recent (should NOT be deleted - within keep_days)
|
||||
recent_expired_session = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="recent_expired_jti",
|
||||
device_name="Recent Device",
|
||||
ip_address="192.168.1.3",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
session.add_all([active_session, old_expired_session, recent_expired_session])
|
||||
await session.commit()
|
||||
|
||||
# Mock AsyncSessionLocal to return our test session
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
# Should only delete old_expired_session
|
||||
assert deleted_count == 1
|
||||
|
||||
# Verify remaining sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(UserSession))
|
||||
remaining = result.scalars().all()
|
||||
assert len(remaining) == 2
|
||||
jtis = [s.refresh_token_jti for s in remaining]
|
||||
assert "active_jti_123" in jtis
|
||||
assert "recent_expired_jti" in jtis
|
||||
assert "old_expired_jti" not in jtis
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user):
|
||||
"""Test cleanup when no sessions meet deletion criteria."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="active_only_jti",
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(active)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_empty_database(self, async_test_db):
|
||||
"""Test cleanup with no sessions in database."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user):
|
||||
"""Test cleanup with keep_days=0 deletes all inactive expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
today_expired = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="today_expired_jti",
|
||||
device_name="Today Expired",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(today_expired)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=0)
|
||||
|
||||
assert deleted_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup uses bulk DELETE for many sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create 50 expired sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sessions_to_add = []
|
||||
for i in range(50):
|
||||
expired = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"bulk_jti_{i}",
|
||||
device_name=f"Device {i}",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
sessions_to_add.append(expired)
|
||||
session.add_all(sessions_to_add)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_database_error_returns_zero(self, async_test_db):
|
||||
"""Test cleanup returns 0 on database errors (doesn't crash)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Mock session_crud.cleanup_expired to raise error
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup:
|
||||
mock_cleanup.side_effect = Exception("Database connection lost")
|
||||
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
# Should not crash, should return 0
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 0
|
||||
|
||||
|
||||
class TestGetSessionStatistics:
|
||||
"""Tests for get_session_statistics function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_with_sessions(self, async_test_db, async_test_user):
|
||||
"""Test getting session statistics with various session types."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# 2 active, not expired
|
||||
for i in range(2):
|
||||
active = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"active_stat_{i}",
|
||||
device_name=f"Active {i}",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(active)
|
||||
|
||||
# 3 inactive, expired
|
||||
for i in range(3):
|
||||
inactive = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"inactive_stat_{i}",
|
||||
device_name=f"Inactive {i}",
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=2),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(inactive)
|
||||
|
||||
# 1 active but expired
|
||||
expired_active = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti="expired_active_stat",
|
||||
device_name="Expired Active",
|
||||
ip_address="192.168.1.3",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(expired_active)
|
||||
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats["total"] == 6
|
||||
assert stats["active"] == 3 # 2 active + 1 expired but active
|
||||
assert stats["inactive"] == 3
|
||||
assert stats["expired"] == 4 # 3 inactive expired + 1 active expired
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_empty_database(self, async_test_db):
|
||||
"""Test getting statistics with no sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats["total"] == 0
|
||||
assert stats["active"] == 0
|
||||
assert stats["inactive"] == 0
|
||||
assert stats["expired"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_database_error_returns_empty_dict(self, async_test_db):
|
||||
"""Test statistics returns empty dict on database errors."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a mock that raises on execute
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.side_effect = Exception("Database error")
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_session_local():
|
||||
yield mock_session
|
||||
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', return_value=mock_session_local()):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats == {}
|
||||
|
||||
|
||||
class TestConcurrentCleanup:
|
||||
"""Tests for concurrent cleanup scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_cleanup_no_duplicate_deletes(self, async_test_db, async_test_user):
|
||||
"""Test concurrent cleanups don't cause race conditions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create 10 expired sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(10):
|
||||
expired = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=f"concurrent_jti_{i}",
|
||||
device_name=f"Device {i}",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(expired)
|
||||
await session.commit()
|
||||
|
||||
# Run two cleanups concurrently
|
||||
# Use side_effect to return fresh session instances for each call
|
||||
with patch('app.services.session_cleanup.AsyncSessionLocal', side_effect=lambda: AsyncTestingSessionLocal()):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
results = await asyncio.gather(
|
||||
cleanup_expired_sessions(keep_days=30),
|
||||
cleanup_expired_sessions(keep_days=30)
|
||||
)
|
||||
|
||||
# Both should report deleting sessions (may overlap due to transaction timing)
|
||||
assert sum(results) >= 10
|
||||
|
||||
# Verify all are deleted
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(UserSession))
|
||||
remaining = result.scalars().all()
|
||||
assert len(remaining) == 0
|
||||
425
backend/tests/utils/test_device.py
Normal file
425
backend/tests/utils/test_device.py
Normal file
@@ -0,0 +1,425 @@
|
||||
# tests/utils/test_device.py
|
||||
"""
|
||||
Comprehensive tests for device utility functions.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.utils.device import (
|
||||
extract_device_info,
|
||||
parse_device_name,
|
||||
extract_browser,
|
||||
get_client_ip,
|
||||
is_mobile_device,
|
||||
get_device_type
|
||||
)
|
||||
|
||||
|
||||
class TestParseDeviceName:
|
||||
"""Tests for parse_device_name function."""
|
||||
|
||||
def test_parse_device_name_empty_string(self):
|
||||
"""Test parsing empty user agent."""
|
||||
result = parse_device_name("")
|
||||
assert result == "Unknown device"
|
||||
|
||||
def test_parse_device_name_iphone(self):
|
||||
"""Test parsing iPhone user agent."""
|
||||
ua = "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "iPhone"
|
||||
|
||||
def test_parse_device_name_ipad(self):
|
||||
"""Test parsing iPad user agent."""
|
||||
ua = "Mozilla/5.0 (iPad; CPU OS 14_0 like Mac OS X)"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "iPad"
|
||||
|
||||
def test_parse_device_name_android_with_model(self):
|
||||
"""Test parsing Android user agent with device model."""
|
||||
ua = "Mozilla/5.0 (Linux; Android 11; SM-G991B Build/RP1A)"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Android (Sm-G991B)"
|
||||
|
||||
def test_parse_device_name_android_without_model(self):
|
||||
"""Test parsing Android user agent without model."""
|
||||
ua = "Mozilla/5.0 (Linux; Android)"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Android device"
|
||||
|
||||
def test_parse_device_name_windows_phone(self):
|
||||
"""Test parsing Windows Phone user agent."""
|
||||
ua = "Mozilla/5.0 (Windows Phone 10.0)"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Windows Phone"
|
||||
|
||||
def test_parse_device_name_mac(self):
|
||||
"""Test parsing Mac user agent."""
|
||||
ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Chrome on Mac"
|
||||
|
||||
def test_parse_device_name_windows(self):
|
||||
"""Test parsing Windows user agent."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Chrome on Windows"
|
||||
|
||||
def test_parse_device_name_linux(self):
|
||||
"""Test parsing Linux user agent."""
|
||||
ua = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Chrome on Linux"
|
||||
|
||||
def test_parse_device_name_chromebook(self):
|
||||
"""Test parsing Chromebook user agent."""
|
||||
ua = "Mozilla/5.0 (X11; CrOS x86_64 14092.0.0) AppleWebKit/537.36"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Chromebook"
|
||||
|
||||
def test_parse_device_name_tablet(self):
|
||||
"""Test parsing generic tablet user agent."""
|
||||
ua = "Mozilla/5.0 (Linux; Android 9; Tablet) AppleWebKit/537.36"
|
||||
result = parse_device_name(ua)
|
||||
# Should match tablet first since it's in the string
|
||||
assert "Tablet" in result or "Android" in result
|
||||
|
||||
def test_parse_device_name_smart_tv(self):
|
||||
"""Test parsing Smart TV user agent."""
|
||||
ua = "Mozilla/5.0 (SMART-TV; Linux; Tizen 2.3)"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Smart TV"
|
||||
|
||||
def test_parse_device_name_playstation(self):
|
||||
"""Test parsing PlayStation user agent."""
|
||||
ua = "Mozilla/5.0 (PlayStation 4 5.50)"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "PlayStation"
|
||||
|
||||
def test_parse_device_name_xbox(self):
|
||||
"""Test parsing Xbox user agent."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64; Xbox; Xbox One)"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Xbox"
|
||||
|
||||
def test_parse_device_name_nintendo(self):
|
||||
"""Test parsing Nintendo user agent."""
|
||||
ua = "Mozilla/5.0 (Nintendo Switch)"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Nintendo"
|
||||
|
||||
def test_parse_device_name_unknown(self):
|
||||
"""Test parsing completely unknown user agent."""
|
||||
ua = "SomeRandomBot/1.0"
|
||||
result = parse_device_name(ua)
|
||||
assert result == "Unknown device"
|
||||
|
||||
|
||||
class TestExtractBrowser:
|
||||
"""Tests for extract_browser function."""
|
||||
|
||||
def test_extract_browser_empty_string(self):
|
||||
"""Test extracting browser from empty user agent."""
|
||||
result = extract_browser("")
|
||||
assert result is None
|
||||
|
||||
def test_extract_browser_none(self):
|
||||
"""Test extracting browser from None."""
|
||||
result = extract_browser(None)
|
||||
assert result is None
|
||||
|
||||
def test_extract_browser_edge(self):
|
||||
"""Test extracting Edge browser."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36 Edg/96.0.1054.62"
|
||||
result = extract_browser(ua)
|
||||
assert result == "Edge"
|
||||
|
||||
def test_extract_browser_edge_legacy(self):
|
||||
"""Test extracting legacy Edge browser."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
|
||||
result = extract_browser(ua)
|
||||
assert result == "Edge"
|
||||
|
||||
def test_extract_browser_opera(self):
|
||||
"""Test extracting Opera browser."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36 OPR/82.0.4227.50"
|
||||
result = extract_browser(ua)
|
||||
assert result == "Opera"
|
||||
|
||||
def test_extract_browser_chrome(self):
|
||||
"""Test extracting Chrome browser."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36"
|
||||
result = extract_browser(ua)
|
||||
assert result == "Chrome"
|
||||
|
||||
def test_extract_browser_safari(self):
|
||||
"""Test extracting Safari browser."""
|
||||
ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/15.0 Safari/605.1.15"
|
||||
result = extract_browser(ua)
|
||||
assert result == "Safari"
|
||||
|
||||
def test_extract_browser_firefox(self):
|
||||
"""Test extracting Firefox browser."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:94.0) Gecko/20100101 Firefox/94.0"
|
||||
result = extract_browser(ua)
|
||||
assert result == "Firefox"
|
||||
|
||||
def test_extract_browser_internet_explorer_msie(self):
|
||||
"""Test extracting Internet Explorer (MSIE)."""
|
||||
ua = "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 10.0)"
|
||||
result = extract_browser(ua)
|
||||
assert result == "Internet Explorer"
|
||||
|
||||
def test_extract_browser_internet_explorer_trident(self):
|
||||
"""Test extracting Internet Explorer (Trident)."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Trident/7.0; rv:11.0) like Gecko"
|
||||
result = extract_browser(ua)
|
||||
assert result == "Internet Explorer"
|
||||
|
||||
def test_extract_browser_unknown(self):
|
||||
"""Test extracting from unknown browser."""
|
||||
ua = "SomeRandomBot/1.0"
|
||||
result = extract_browser(ua)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetClientIp:
|
||||
"""Tests for get_client_ip function."""
|
||||
|
||||
def test_get_client_ip_x_forwarded_for_single(self):
|
||||
"""Test getting IP from X-Forwarded-For with single IP."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"x-forwarded-for": "192.168.1.100"}
|
||||
request.client = None
|
||||
|
||||
result = get_client_ip(request)
|
||||
assert result == "192.168.1.100"
|
||||
|
||||
def test_get_client_ip_x_forwarded_for_multiple(self):
|
||||
"""Test getting IP from X-Forwarded-For with multiple IPs."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"x-forwarded-for": "192.168.1.100, 10.0.0.1, 172.16.0.1"}
|
||||
request.client = None
|
||||
|
||||
result = get_client_ip(request)
|
||||
assert result == "192.168.1.100"
|
||||
|
||||
def test_get_client_ip_x_real_ip(self):
|
||||
"""Test getting IP from X-Real-IP."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"x-real-ip": "192.168.1.200"}
|
||||
request.client = None
|
||||
|
||||
result = get_client_ip(request)
|
||||
assert result == "192.168.1.200"
|
||||
|
||||
def test_get_client_ip_direct_connection(self):
|
||||
"""Test getting IP from direct connection."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
request.client = Mock()
|
||||
request.client.host = "192.168.1.50"
|
||||
|
||||
result = get_client_ip(request)
|
||||
assert result == "192.168.1.50"
|
||||
|
||||
def test_get_client_ip_no_client(self):
|
||||
"""Test getting IP when no client info available."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
request.client = None
|
||||
|
||||
result = get_client_ip(request)
|
||||
assert result is None
|
||||
|
||||
def test_get_client_ip_client_no_host(self):
|
||||
"""Test getting IP when client exists but no host."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
request.client = Mock()
|
||||
request.client.host = None
|
||||
|
||||
result = get_client_ip(request)
|
||||
assert result is None
|
||||
|
||||
def test_get_client_ip_priority_x_forwarded_for(self):
|
||||
"""Test that X-Forwarded-For has priority over X-Real-IP."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {
|
||||
"x-forwarded-for": "192.168.1.100",
|
||||
"x-real-ip": "192.168.1.200"
|
||||
}
|
||||
request.client = Mock()
|
||||
request.client.host = "192.168.1.50"
|
||||
|
||||
result = get_client_ip(request)
|
||||
assert result == "192.168.1.100"
|
||||
|
||||
def test_get_client_ip_priority_x_real_ip_over_client(self):
|
||||
"""Test that X-Real-IP has priority over client.host."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"x-real-ip": "192.168.1.200"}
|
||||
request.client = Mock()
|
||||
request.client.host = "192.168.1.50"
|
||||
|
||||
result = get_client_ip(request)
|
||||
assert result == "192.168.1.200"
|
||||
|
||||
|
||||
class TestIsMobileDevice:
|
||||
"""Tests for is_mobile_device function."""
|
||||
|
||||
def test_is_mobile_device_empty_string(self):
|
||||
"""Test with empty string."""
|
||||
result = is_mobile_device("")
|
||||
assert result is False
|
||||
|
||||
def test_is_mobile_device_iphone(self):
|
||||
"""Test iPhone user agent."""
|
||||
ua = "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)"
|
||||
result = is_mobile_device(ua)
|
||||
assert result is True
|
||||
|
||||
def test_is_mobile_device_android(self):
|
||||
"""Test Android user agent."""
|
||||
ua = "Mozilla/5.0 (Linux; Android 11)"
|
||||
result = is_mobile_device(ua)
|
||||
assert result is True
|
||||
|
||||
def test_is_mobile_device_ipad(self):
|
||||
"""Test iPad user agent."""
|
||||
ua = "Mozilla/5.0 (iPad; CPU OS 14_0 like Mac OS X)"
|
||||
result = is_mobile_device(ua)
|
||||
assert result is True
|
||||
|
||||
def test_is_mobile_device_desktop(self):
|
||||
"""Test desktop user agent."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||
result = is_mobile_device(ua)
|
||||
assert result is False
|
||||
|
||||
def test_is_mobile_device_blackberry(self):
|
||||
"""Test BlackBerry user agent."""
|
||||
ua = "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900)"
|
||||
result = is_mobile_device(ua)
|
||||
assert result is True
|
||||
|
||||
def test_is_mobile_device_windows_phone(self):
|
||||
"""Test Windows Phone user agent."""
|
||||
ua = "Mozilla/5.0 (Windows Phone 10.0)"
|
||||
result = is_mobile_device(ua)
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestGetDeviceType:
|
||||
"""Tests for get_device_type function."""
|
||||
|
||||
def test_get_device_type_empty_string(self):
|
||||
"""Test with empty string."""
|
||||
result = get_device_type("")
|
||||
assert result == "other"
|
||||
|
||||
def test_get_device_type_ipad(self):
|
||||
"""Test iPad returns tablet."""
|
||||
ua = "Mozilla/5.0 (iPad; CPU OS 14_0 like Mac OS X)"
|
||||
result = get_device_type(ua)
|
||||
assert result == "tablet"
|
||||
|
||||
def test_get_device_type_tablet(self):
|
||||
"""Test generic tablet."""
|
||||
ua = "Mozilla/5.0 (Linux; Android 9; Tablet)"
|
||||
result = get_device_type(ua)
|
||||
assert result == "tablet"
|
||||
|
||||
def test_get_device_type_iphone(self):
|
||||
"""Test iPhone returns mobile."""
|
||||
ua = "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)"
|
||||
result = get_device_type(ua)
|
||||
assert result == "mobile"
|
||||
|
||||
def test_get_device_type_android_mobile(self):
|
||||
"""Test Android mobile."""
|
||||
ua = "Mozilla/5.0 (Linux; Android 11; SM-G991B) Mobile"
|
||||
result = get_device_type(ua)
|
||||
assert result == "mobile"
|
||||
|
||||
def test_get_device_type_windows_desktop(self):
|
||||
"""Test Windows desktop."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"
|
||||
result = get_device_type(ua)
|
||||
assert result == "desktop"
|
||||
|
||||
def test_get_device_type_mac_desktop(self):
|
||||
"""Test Mac desktop."""
|
||||
ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)"
|
||||
result = get_device_type(ua)
|
||||
assert result == "desktop"
|
||||
|
||||
def test_get_device_type_linux_desktop(self):
|
||||
"""Test Linux desktop."""
|
||||
ua = "Mozilla/5.0 (X11; Linux x86_64)"
|
||||
result = get_device_type(ua)
|
||||
assert result == "desktop"
|
||||
|
||||
def test_get_device_type_chromebook(self):
|
||||
"""Test Chromebook."""
|
||||
ua = "Mozilla/5.0 (X11; CrOS x86_64 14092.0.0)"
|
||||
result = get_device_type(ua)
|
||||
assert result == "desktop"
|
||||
|
||||
def test_get_device_type_unknown(self):
|
||||
"""Test unknown device."""
|
||||
ua = "SomeRandomBot/1.0"
|
||||
result = get_device_type(ua)
|
||||
assert result == "other"
|
||||
|
||||
|
||||
class TestExtractDeviceInfo:
|
||||
"""Tests for extract_device_info function."""
|
||||
|
||||
def test_extract_device_info_complete(self):
|
||||
"""Test extracting device info with all headers."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {
|
||||
"user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)",
|
||||
"x-device-id": "device-123-456",
|
||||
"x-forwarded-for": "192.168.1.100"
|
||||
}
|
||||
request.client = None
|
||||
|
||||
result = extract_device_info(request)
|
||||
|
||||
assert result.device_name == "iPhone"
|
||||
assert result.device_id == "device-123-456"
|
||||
assert result.ip_address == "192.168.1.100"
|
||||
assert "iPhone" in result.user_agent
|
||||
assert result.location_city is None
|
||||
assert result.location_country is None
|
||||
|
||||
def test_extract_device_info_minimal(self):
|
||||
"""Test extracting device info with minimal headers."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
request.client = Mock()
|
||||
request.client.host = "127.0.0.1"
|
||||
|
||||
result = extract_device_info(request)
|
||||
|
||||
assert result.device_name == "Unknown device"
|
||||
assert result.device_id is None
|
||||
assert result.ip_address == "127.0.0.1"
|
||||
assert result.user_agent is None
|
||||
|
||||
def test_extract_device_info_long_user_agent(self):
|
||||
"""Test that user agent is truncated to 500 chars."""
|
||||
long_ua = "A" * 600
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"user-agent": long_ua}
|
||||
request.client = None
|
||||
|
||||
result = extract_device_info(request)
|
||||
|
||||
assert len(result.user_agent) == 500
|
||||
assert result.user_agent == "A" * 500
|
||||
Reference in New Issue
Block a user