diff --git a/backend/app/alembic/env.py b/backend/app/alembic/env.py index ccfa7cd..7084038 100644 --- a/backend/app/alembic/env.py +++ b/backend/app/alembic/env.py @@ -2,12 +2,11 @@ import sys from logging.config import fileConfig from pathlib import Path -from sqlalchemy import engine_from_config, pool, text, create_engine +from alembic import context +from sqlalchemy import create_engine, engine_from_config, pool, text from sqlalchemy.engine.url import make_url from sqlalchemy.exc import OperationalError -from alembic import context - # Get the path to the app directory (parent of 'alembic') app_dir = Path(__file__).resolve().parent.parent # Add the app directory to Python path @@ -66,7 +65,9 @@ def ensure_database_exists(db_url: str) -> None: admin_url = url.set(database="postgres") # CREATE DATABASE cannot run inside a transaction - admin_engine = create_engine(str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool) + admin_engine = create_engine( + str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool + ) try: with admin_engine.connect() as conn: exists = conn.execute( @@ -122,9 +123,7 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() @@ -133,4 +132,4 @@ def run_migrations_online() -> None: if context.is_offline_mode(): run_migrations_offline() else: - run_migrations_online() \ No newline at end of file + run_migrations_online() diff --git a/backend/app/alembic/versions/1174fffbe3e4_add_performance_indexes.py b/backend/app/alembic/versions/1174fffbe3e4_add_performance_indexes.py index 9f58956..4aa562a 100644 --- a/backend/app/alembic/versions/1174fffbe3e4_add_performance_indexes.py +++ b/backend/app/alembic/versions/1174fffbe3e4_add_performance_indexes.py @@ -5,17 +5,17 @@ Revises: fbf6318a8a36 Create Date: 2025-11-01 04:15:25.367010 """ -from typing import Sequence, Union + +from collections.abc import Sequence import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. -revision: str = '1174fffbe3e4' -down_revision: Union[str, None] = 'fbf6318a8a36' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +revision: str = "1174fffbe3e4" +down_revision: str | None = "fbf6318a8a36" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: @@ -24,46 +24,46 @@ def upgrade() -> None: # Index for session cleanup queries # Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff op.create_index( - 'ix_user_sessions_cleanup', - 'user_sessions', - ['is_active', 'expires_at', 'created_at'], + "ix_user_sessions_cleanup", + "user_sessions", + ["is_active", "expires_at", "created_at"], unique=False, - postgresql_where=sa.text('is_active = false') + postgresql_where=sa.text("is_active = false"), ) # Index for user search queries (basic trigram support without pg_trgm extension) # Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%' # Note: For better performance, consider enabling pg_trgm extension op.create_index( - 'ix_users_email_lower', - 'users', - [sa.text('LOWER(email)')], + "ix_users_email_lower", + "users", + [sa.text("LOWER(email)")], unique=False, - postgresql_where=sa.text('deleted_at IS NULL') + postgresql_where=sa.text("deleted_at IS NULL"), ) op.create_index( - 'ix_users_first_name_lower', - 'users', - [sa.text('LOWER(first_name)')], + "ix_users_first_name_lower", + "users", + [sa.text("LOWER(first_name)")], unique=False, - postgresql_where=sa.text('deleted_at IS NULL') + postgresql_where=sa.text("deleted_at IS NULL"), ) op.create_index( - 'ix_users_last_name_lower', - 'users', - [sa.text('LOWER(last_name)')], + "ix_users_last_name_lower", + "users", + [sa.text("LOWER(last_name)")], unique=False, - postgresql_where=sa.text('deleted_at IS NULL') + postgresql_where=sa.text("deleted_at IS NULL"), ) # Index for organization search op.create_index( - 'ix_organizations_name_lower', - 'organizations', - [sa.text('LOWER(name)')], - unique=False + "ix_organizations_name_lower", + "organizations", + [sa.text("LOWER(name)")], + unique=False, ) @@ -71,8 +71,8 @@ def downgrade() -> None: """Remove performance indexes.""" # Drop indexes in reverse order - op.drop_index('ix_organizations_name_lower', table_name='organizations') - op.drop_index('ix_users_last_name_lower', table_name='users') - op.drop_index('ix_users_first_name_lower', table_name='users') - op.drop_index('ix_users_email_lower', table_name='users') - op.drop_index('ix_user_sessions_cleanup', table_name='user_sessions') + op.drop_index("ix_organizations_name_lower", table_name="organizations") + op.drop_index("ix_users_last_name_lower", table_name="users") + op.drop_index("ix_users_first_name_lower", table_name="users") + op.drop_index("ix_users_email_lower", table_name="users") + op.drop_index("ix_user_sessions_cleanup", table_name="user_sessions") diff --git a/backend/app/alembic/versions/2d0fcec3b06d_add_soft_delete_to_users.py b/backend/app/alembic/versions/2d0fcec3b06d_add_soft_delete_to_users.py index da49433..0531bc6 100644 --- a/backend/app/alembic/versions/2d0fcec3b06d_add_soft_delete_to_users.py +++ b/backend/app/alembic/versions/2d0fcec3b06d_add_soft_delete_to_users.py @@ -5,30 +5,32 @@ Revises: 9e4f2a1b8c7d Create Date: 2025-10-30 16:40:21.000021 """ -from typing import Sequence, Union + +from collections.abc import Sequence import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. -revision: str = '2d0fcec3b06d' -down_revision: Union[str, None] = '9e4f2a1b8c7d' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +revision: str = "2d0fcec3b06d" +down_revision: str | None = "9e4f2a1b8c7d" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: # Add deleted_at column for soft deletes - op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) + op.add_column( + "users", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True) + ) # Add index on deleted_at for efficient queries - op.create_index('ix_users_deleted_at', 'users', ['deleted_at']) + op.create_index("ix_users_deleted_at", "users", ["deleted_at"]) def downgrade() -> None: # Remove index - op.drop_index('ix_users_deleted_at', table_name='users') + op.drop_index("ix_users_deleted_at", table_name="users") # Remove column - op.drop_column('users', 'deleted_at') + op.drop_column("users", "deleted_at") diff --git a/backend/app/alembic/versions/38bf9e7e74b3_add_all_initial_models.py b/backend/app/alembic/versions/38bf9e7e74b3_add_all_initial_models.py index 1a990d1..d247e8e 100644 --- a/backend/app/alembic/versions/38bf9e7e74b3_add_all_initial_models.py +++ b/backend/app/alembic/versions/38bf9e7e74b3_add_all_initial_models.py @@ -5,42 +5,42 @@ Revises: 7396957cbe80 Create Date: 2025-02-28 09:19:33.212278 """ -from typing import Sequence, Union + +from collections.abc import Sequence import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. -revision: str = '38bf9e7e74b3' -down_revision: Union[str, None] = '7396957cbe80' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +revision: str = "38bf9e7e74b3" +down_revision: str | None = "7396957cbe80" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: - - op.create_table('users', - sa.Column('email', sa.String(), nullable=False), - sa.Column('password_hash', sa.String(), nullable=False), - sa.Column('first_name', sa.String(), nullable=False), - sa.Column('last_name', sa.String(), nullable=True), - sa.Column('phone_number', sa.String(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=False), - sa.Column('is_superuser', sa.Boolean(), nullable=False), - sa.Column('preferences', sa.JSON(), nullable=True), - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_table( + "users", + sa.Column("email", sa.String(), nullable=False), + sa.Column("password_hash", sa.String(), nullable=False), + sa.Column("first_name", sa.String(), nullable=False), + sa.Column("last_name", sa.String(), nullable=True), + sa.Column("phone_number", sa.String(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("is_superuser", sa.Boolean(), nullable=False), + sa.Column("preferences", sa.JSON(), nullable=True), + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) + op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_users_email'), table_name='users') - op.drop_table('users') + op.drop_index(op.f("ix_users_email"), table_name="users") + op.drop_table("users") # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/549b50ea888d_add_user_sessions_table.py b/backend/app/alembic/versions/549b50ea888d_add_user_sessions_table.py index 7c31b1e..1da902a 100644 --- a/backend/app/alembic/versions/549b50ea888d_add_user_sessions_table.py +++ b/backend/app/alembic/versions/549b50ea888d_add_user_sessions_table.py @@ -5,98 +5,85 @@ Revises: b76c725fc3cf Create Date: 2025-10-31 07:41:18.729544 """ -from typing import Sequence, Union + +from collections.abc import Sequence import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. -revision: str = '549b50ea888d' -down_revision: Union[str, None] = 'b76c725fc3cf' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +revision: str = "549b50ea888d" +down_revision: str | None = "b76c725fc3cf" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: # Create user_sessions table for per-device session management op.create_table( - 'user_sessions', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('user_id', sa.UUID(), nullable=False), - sa.Column('refresh_token_jti', sa.String(length=255), nullable=False), - sa.Column('device_name', sa.String(length=255), nullable=True), - sa.Column('device_id', sa.String(length=255), nullable=True), - sa.Column('ip_address', sa.String(length=45), nullable=True), - sa.Column('user_agent', sa.String(length=500), nullable=True), - sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), - sa.Column('location_city', sa.String(length=100), nullable=True), - sa.Column('location_country', sa.String(length=100), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.PrimaryKeyConstraint('id') + "user_sessions", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=False), + sa.Column("refresh_token_jti", sa.String(length=255), nullable=False), + sa.Column("device_name", sa.String(length=255), nullable=True), + sa.Column("device_id", sa.String(length=255), nullable=True), + sa.Column("ip_address", sa.String(length=45), nullable=True), + sa.Column("user_agent", sa.String(length=500), nullable=True), + sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("location_city", sa.String(length=100), nullable=True), + sa.Column("location_country", sa.String(length=100), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), ) # Create foreign key to users table op.create_foreign_key( - 'fk_user_sessions_user_id', - 'user_sessions', - 'users', - ['user_id'], - ['id'], - ondelete='CASCADE' + "fk_user_sessions_user_id", + "user_sessions", + "users", + ["user_id"], + ["id"], + ondelete="CASCADE", ) # Create indexes for performance # 1. Lookup session by refresh token JTI (most common query) op.create_index( - 'ix_user_sessions_jti', - 'user_sessions', - ['refresh_token_jti'], - unique=True + "ix_user_sessions_jti", "user_sessions", ["refresh_token_jti"], unique=True ) # 2. Lookup sessions by user ID - op.create_index( - 'ix_user_sessions_user_id', - 'user_sessions', - ['user_id'] - ) + op.create_index("ix_user_sessions_user_id", "user_sessions", ["user_id"]) # 3. Composite index for active sessions by user op.create_index( - 'ix_user_sessions_user_active', - 'user_sessions', - ['user_id', 'is_active'] + "ix_user_sessions_user_active", "user_sessions", ["user_id", "is_active"] ) # 4. Index on expires_at for cleanup job - op.create_index( - 'ix_user_sessions_expires_at', - 'user_sessions', - ['expires_at'] - ) + op.create_index("ix_user_sessions_expires_at", "user_sessions", ["expires_at"]) # 5. Composite index for active session lookup by JTI op.create_index( - 'ix_user_sessions_jti_active', - 'user_sessions', - ['refresh_token_jti', 'is_active'] + "ix_user_sessions_jti_active", + "user_sessions", + ["refresh_token_jti", "is_active"], ) def downgrade() -> None: # Drop indexes first - op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions') - op.drop_index('ix_user_sessions_expires_at', table_name='user_sessions') - op.drop_index('ix_user_sessions_user_active', table_name='user_sessions') - op.drop_index('ix_user_sessions_user_id', table_name='user_sessions') - op.drop_index('ix_user_sessions_jti', table_name='user_sessions') + op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions") + op.drop_index("ix_user_sessions_expires_at", table_name="user_sessions") + op.drop_index("ix_user_sessions_user_active", table_name="user_sessions") + op.drop_index("ix_user_sessions_user_id", table_name="user_sessions") + op.drop_index("ix_user_sessions_jti", table_name="user_sessions") # Drop foreign key - op.drop_constraint('fk_user_sessions_user_id', 'user_sessions', type_='foreignkey') + op.drop_constraint("fk_user_sessions_user_id", "user_sessions", type_="foreignkey") # Drop table - op.drop_table('user_sessions') + op.drop_table("user_sessions") diff --git a/backend/app/alembic/versions/7396957cbe80_initial_empty_migration.py b/backend/app/alembic/versions/7396957cbe80_initial_empty_migration.py index 29dde03..b7bd325 100644 --- a/backend/app/alembic/versions/7396957cbe80_initial_empty_migration.py +++ b/backend/app/alembic/versions/7396957cbe80_initial_empty_migration.py @@ -1,19 +1,18 @@ """Initial empty migration Revision ID: 7396957cbe80 -Revises: +Revises: Create Date: 2025-02-27 12:47:46.445313 """ -from typing import Sequence, Union -from alembic import op +from collections.abc import Sequence # revision identifiers, used by Alembic. -revision: str = '7396957cbe80' -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +revision: str = "7396957cbe80" +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: diff --git a/backend/app/alembic/versions/9e4f2a1b8c7d_add_missing_indexes_and_fix_column_types.py b/backend/app/alembic/versions/9e4f2a1b8c7d_add_missing_indexes_and_fix_column_types.py index dba8ef6..7f45023 100644 --- a/backend/app/alembic/versions/9e4f2a1b8c7d_add_missing_indexes_and_fix_column_types.py +++ b/backend/app/alembic/versions/9e4f2a1b8c7d_add_missing_indexes_and_fix_column_types.py @@ -5,80 +5,112 @@ Revises: 38bf9e7e74b3 Create Date: 2025-10-30 10:00:00.000000 """ -from typing import Sequence, Union + +from collections.abc import Sequence import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. -revision: str = '9e4f2a1b8c7d' -down_revision: Union[str, None] = '38bf9e7e74b3' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +revision: str = "9e4f2a1b8c7d" +down_revision: str | None = "38bf9e7e74b3" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: # Add missing indexes for is_active and is_superuser - op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False) - op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False) + op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False) + op.create_index( + op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False + ) # Fix column types to match model definitions with explicit lengths - op.alter_column('users', 'email', - existing_type=sa.String(), - type_=sa.String(length=255), - nullable=False) + op.alter_column( + "users", + "email", + existing_type=sa.String(), + type_=sa.String(length=255), + nullable=False, + ) - op.alter_column('users', 'password_hash', - existing_type=sa.String(), - type_=sa.String(length=255), - nullable=False) + op.alter_column( + "users", + "password_hash", + existing_type=sa.String(), + type_=sa.String(length=255), + nullable=False, + ) - op.alter_column('users', 'first_name', - existing_type=sa.String(), - type_=sa.String(length=100), - nullable=False, - server_default='user') # Add server default + op.alter_column( + "users", + "first_name", + existing_type=sa.String(), + type_=sa.String(length=100), + nullable=False, + server_default="user", + ) # Add server default - op.alter_column('users', 'last_name', - existing_type=sa.String(), - type_=sa.String(length=100), - nullable=True) + op.alter_column( + "users", + "last_name", + existing_type=sa.String(), + type_=sa.String(length=100), + nullable=True, + ) - op.alter_column('users', 'phone_number', - existing_type=sa.String(), - type_=sa.String(length=20), - nullable=True) + op.alter_column( + "users", + "phone_number", + existing_type=sa.String(), + type_=sa.String(length=20), + nullable=True, + ) def downgrade() -> None: # Revert column types - op.alter_column('users', 'phone_number', - existing_type=sa.String(length=20), - type_=sa.String(), - nullable=True) + op.alter_column( + "users", + "phone_number", + existing_type=sa.String(length=20), + type_=sa.String(), + nullable=True, + ) - op.alter_column('users', 'last_name', - existing_type=sa.String(length=100), - type_=sa.String(), - nullable=True) + op.alter_column( + "users", + "last_name", + existing_type=sa.String(length=100), + type_=sa.String(), + nullable=True, + ) - op.alter_column('users', 'first_name', - existing_type=sa.String(length=100), - type_=sa.String(), - nullable=False, - server_default=None) # Remove server default + op.alter_column( + "users", + "first_name", + existing_type=sa.String(length=100), + type_=sa.String(), + nullable=False, + server_default=None, + ) # Remove server default - op.alter_column('users', 'password_hash', - existing_type=sa.String(length=255), - type_=sa.String(), - nullable=False) + op.alter_column( + "users", + "password_hash", + existing_type=sa.String(length=255), + type_=sa.String(), + nullable=False, + ) - op.alter_column('users', 'email', - existing_type=sa.String(length=255), - type_=sa.String(), - nullable=False) + op.alter_column( + "users", + "email", + existing_type=sa.String(length=255), + type_=sa.String(), + nullable=False, + ) # Drop indexes - op.drop_index(op.f('ix_users_is_superuser'), table_name='users') - op.drop_index(op.f('ix_users_is_active'), table_name='users') + op.drop_index(op.f("ix_users_is_superuser"), table_name="users") + op.drop_index(op.f("ix_users_is_active"), table_name="users") diff --git a/backend/app/alembic/versions/b76c725fc3cf_add_composite_indexes.py b/backend/app/alembic/versions/b76c725fc3cf_add_composite_indexes.py index f7b7207..30e2426 100644 --- a/backend/app/alembic/versions/b76c725fc3cf_add_composite_indexes.py +++ b/backend/app/alembic/versions/b76c725fc3cf_add_composite_indexes.py @@ -5,17 +5,17 @@ Revises: 2d0fcec3b06d Create Date: 2025-10-30 16:41:33.273135 """ -from typing import Sequence, Union + +from collections.abc import Sequence import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. -revision: str = 'b76c725fc3cf' -down_revision: Union[str, None] = '2d0fcec3b06d' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +revision: str = "b76c725fc3cf" +down_revision: str | None = "2d0fcec3b06d" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: @@ -23,30 +23,26 @@ def upgrade() -> None: # Composite index for filtering active users by role op.create_index( - 'ix_users_active_superuser', - 'users', - ['is_active', 'is_superuser'], - postgresql_where=sa.text('deleted_at IS NULL') + "ix_users_active_superuser", + "users", + ["is_active", "is_superuser"], + postgresql_where=sa.text("deleted_at IS NULL"), ) # Composite index for sorting active users by creation date op.create_index( - 'ix_users_active_created', - 'users', - ['is_active', 'created_at'], - postgresql_where=sa.text('deleted_at IS NULL') + "ix_users_active_created", + "users", + ["is_active", "created_at"], + postgresql_where=sa.text("deleted_at IS NULL"), ) # Composite index for email lookup of non-deleted users - op.create_index( - 'ix_users_email_not_deleted', - 'users', - ['email', 'deleted_at'] - ) + op.create_index("ix_users_email_not_deleted", "users", ["email", "deleted_at"]) def downgrade() -> None: # Remove composite indexes - op.drop_index('ix_users_email_not_deleted', table_name='users') - op.drop_index('ix_users_active_created', table_name='users') - op.drop_index('ix_users_active_superuser', table_name='users') + op.drop_index("ix_users_email_not_deleted", table_name="users") + op.drop_index("ix_users_active_created", table_name="users") + op.drop_index("ix_users_active_superuser", table_name="users") diff --git a/backend/app/alembic/versions/fbf6318a8a36_add_organizations_and_user_organizations.py b/backend/app/alembic/versions/fbf6318a8a36_add_organizations_and_user_organizations.py index ea6af13..64d4299 100644 --- a/backend/app/alembic/versions/fbf6318a8a36_add_organizations_and_user_organizations.py +++ b/backend/app/alembic/versions/fbf6318a8a36_add_organizations_and_user_organizations.py @@ -5,102 +5,123 @@ Revises: 549b50ea888d Create Date: 2025-10-31 12:08:05.141353 """ -from typing import Sequence, Union + +from collections.abc import Sequence import sqlalchemy as sa - from alembic import op # revision identifiers, used by Alembic. -revision: str = 'fbf6318a8a36' -down_revision: Union[str, None] = '549b50ea888d' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +revision: str = "fbf6318a8a36" +down_revision: str | None = "549b50ea888d" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: # Create organizations table op.create_table( - 'organizations', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('name', sa.String(length=255), nullable=False), - sa.Column('slug', sa.String(length=255), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), - sa.Column('settings', sa.JSON(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.PrimaryKeyConstraint('id') + "organizations", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("slug", sa.String(length=255), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("settings", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), ) # Create indexes for organizations - op.create_index('ix_organizations_name', 'organizations', ['name']) - op.create_index('ix_organizations_slug', 'organizations', ['slug'], unique=True) - op.create_index('ix_organizations_is_active', 'organizations', ['is_active']) - op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active']) - op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active']) + op.create_index("ix_organizations_name", "organizations", ["name"]) + op.create_index("ix_organizations_slug", "organizations", ["slug"], unique=True) + op.create_index("ix_organizations_is_active", "organizations", ["is_active"]) + op.create_index( + "ix_organizations_name_active", "organizations", ["name", "is_active"] + ) + op.create_index( + "ix_organizations_slug_active", "organizations", ["slug", "is_active"] + ) # Create user_organizations junction table op.create_table( - 'user_organizations', - sa.Column('user_id', sa.UUID(), nullable=False), - sa.Column('organization_id', sa.UUID(), nullable=False), - sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False, server_default='MEMBER'), - sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), - sa.Column('custom_permissions', sa.String(length=500), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.PrimaryKeyConstraint('user_id', 'organization_id') + "user_organizations", + sa.Column("user_id", sa.UUID(), nullable=False), + sa.Column("organization_id", sa.UUID(), nullable=False), + sa.Column( + "role", + sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"), + nullable=False, + server_default="MEMBER", + ), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("custom_permissions", sa.String(length=500), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("user_id", "organization_id"), ) # Create foreign keys op.create_foreign_key( - 'fk_user_organizations_user_id', - 'user_organizations', - 'users', - ['user_id'], - ['id'], - ondelete='CASCADE' + "fk_user_organizations_user_id", + "user_organizations", + "users", + ["user_id"], + ["id"], + ondelete="CASCADE", ) op.create_foreign_key( - 'fk_user_organizations_organization_id', - 'user_organizations', - 'organizations', - ['organization_id'], - ['id'], - ondelete='CASCADE' + "fk_user_organizations_organization_id", + "user_organizations", + "organizations", + ["organization_id"], + ["id"], + ondelete="CASCADE", ) # Create indexes for user_organizations - op.create_index('ix_user_organizations_role', 'user_organizations', ['role']) - op.create_index('ix_user_organizations_is_active', 'user_organizations', ['is_active']) - op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active']) - op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active']) + op.create_index("ix_user_organizations_role", "user_organizations", ["role"]) + op.create_index( + "ix_user_organizations_is_active", "user_organizations", ["is_active"] + ) + op.create_index( + "ix_user_org_user_active", "user_organizations", ["user_id", "is_active"] + ) + op.create_index( + "ix_user_org_org_active", "user_organizations", ["organization_id", "is_active"] + ) def downgrade() -> None: # Drop indexes for user_organizations - op.drop_index('ix_user_org_org_active', table_name='user_organizations') - op.drop_index('ix_user_org_user_active', table_name='user_organizations') - op.drop_index('ix_user_organizations_is_active', table_name='user_organizations') - op.drop_index('ix_user_organizations_role', table_name='user_organizations') + op.drop_index("ix_user_org_org_active", table_name="user_organizations") + op.drop_index("ix_user_org_user_active", table_name="user_organizations") + op.drop_index("ix_user_organizations_is_active", table_name="user_organizations") + op.drop_index("ix_user_organizations_role", table_name="user_organizations") # Drop foreign keys - op.drop_constraint('fk_user_organizations_organization_id', 'user_organizations', type_='foreignkey') - op.drop_constraint('fk_user_organizations_user_id', 'user_organizations', type_='foreignkey') + op.drop_constraint( + "fk_user_organizations_organization_id", + "user_organizations", + type_="foreignkey", + ) + op.drop_constraint( + "fk_user_organizations_user_id", "user_organizations", type_="foreignkey" + ) # Drop user_organizations table - op.drop_table('user_organizations') + op.drop_table("user_organizations") # Drop indexes for organizations - op.drop_index('ix_organizations_slug_active', table_name='organizations') - op.drop_index('ix_organizations_name_active', table_name='organizations') - op.drop_index('ix_organizations_is_active', table_name='organizations') - op.drop_index('ix_organizations_slug', table_name='organizations') - op.drop_index('ix_organizations_name', table_name='organizations') + op.drop_index("ix_organizations_slug_active", table_name="organizations") + op.drop_index("ix_organizations_name_active", table_name="organizations") + op.drop_index("ix_organizations_is_active", table_name="organizations") + op.drop_index("ix_organizations_slug", table_name="organizations") + op.drop_index("ix_organizations_name", table_name="organizations") # Drop organizations table - op.drop_table('organizations') + op.drop_table("organizations") # Drop enum type - op.execute('DROP TYPE IF EXISTS organizationrole') + op.execute("DROP TYPE IF EXISTS organizationrole") diff --git a/backend/app/api/dependencies/auth.py b/backend/app/api/dependencies/auth.py index 93b7411..5d6a7aa 100755 --- a/backend/app/api/dependencies/auth.py +++ b/backend/app/api/dependencies/auth.py @@ -1,12 +1,10 @@ -from typing import Optional - -from fastapi import Depends, HTTPException, status, Header +from fastapi import Depends, Header, HTTPException, status from fastapi.security import OAuth2PasswordBearer from fastapi.security.utils import get_authorization_scheme_param from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError +from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data from app.core.database import get_db from app.models.user import User @@ -15,8 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") async def get_current_user( - db: AsyncSession = Depends(get_db), - token: str = Depends(oauth2_scheme) + db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme) ) -> User: """ Get the current authenticated user. @@ -36,21 +33,17 @@ async def get_current_user( token_data = get_token_data(token) # Get user from database - result = await db.execute( - select(User).where(User.id == token_data.user_id) - ) + result = await db.execute(select(User).where(User.id == token_data.user_id)) user = result.scalar_one_or_none() if not user: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) if not user.is_active: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Inactive user" + status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user" ) return user @@ -59,19 +52,17 @@ async def get_current_user( raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired", - headers={"WWW-Authenticate": "Bearer"} + headers={"WWW-Authenticate": "Bearer"}, ) except TokenInvalidError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"} + headers={"WWW-Authenticate": "Bearer"}, ) -def get_current_active_user( - current_user: User = Depends(get_current_user) -) -> User: +def get_current_active_user(current_user: User = Depends(get_current_user)) -> User: """ Check if the current user is active. @@ -86,15 +77,12 @@ def get_current_active_user( """ if not current_user.is_active: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Inactive user" + status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user" ) return current_user -def get_current_superuser( - current_user: User = Depends(get_current_user) -) -> User: +def get_current_superuser(current_user: User = Depends(get_current_user)) -> User: """ Check if the current user is a superuser. @@ -109,13 +97,12 @@ def get_current_superuser( """ if not current_user.is_superuser: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not enough permissions" + status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions" ) return current_user -async def get_optional_token(authorization: str = Header(None)) -> Optional[str]: +async def get_optional_token(authorization: str = Header(None)) -> str | None: """ Get the token from the Authorization header without requiring it. @@ -139,9 +126,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str] async def get_optional_current_user( - db: AsyncSession = Depends(get_db), - token: Optional[str] = Depends(get_optional_token) -) -> Optional[User]: + db: AsyncSession = Depends(get_db), token: str | None = Depends(get_optional_token) +) -> User | None: """ Get the current user if authenticated, otherwise return None. Useful for endpoints that work with both authenticated and unauthenticated users. @@ -158,12 +144,10 @@ async def get_optional_current_user( try: token_data = get_token_data(token) - result = await db.execute( - select(User).where(User.id == token_data.user_id) - ) + result = await db.execute(select(User).where(User.id == token_data.user_id)) user = result.scalar_one_or_none() if not user or not user.is_active: return None return user except (TokenExpiredError, TokenInvalidError): - return None \ No newline at end of file + return None diff --git a/backend/app/api/dependencies/permissions.py b/backend/app/api/dependencies/permissions.py index 6a5ecab..0d5aa40 100755 --- a/backend/app/api/dependencies/permissions.py +++ b/backend/app/api/dependencies/permissions.py @@ -7,7 +7,7 @@ These dependencies are optional and flexible: - Use require_org_role for organization-specific access control - Projects can choose to use these or implement their own permission system """ -from typing import Optional + from uuid import UUID from fastapi import Depends, HTTPException, status @@ -20,9 +20,7 @@ from app.models.user import User from app.models.user_organization import OrganizationRole -def require_superuser( - current_user: User = Depends(get_current_user) -) -> User: +def require_superuser(current_user: User = Depends(get_current_user)) -> User: """ Dependency to ensure the current user is a superuser. @@ -36,7 +34,7 @@ def require_superuser( if not current_user.is_superuser: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Superuser privileges required" + detail="Superuser privileges required", ) return current_user @@ -62,7 +60,7 @@ class OrganizationPermission: self, organization_id: UUID, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> User: """ Check if user has required role in the organization. @@ -84,21 +82,19 @@ class OrganizationPermission: # Get user's role in organization user_role = await organization_crud.get_user_role_in_org( - db, - user_id=current_user.id, - organization_id=organization_id + db, user_id=current_user.id, organization_id=organization_id ) if not user_role: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Not a member of this organization" + detail="Not a member of this organization", ) if user_role not in self.allowed_roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}" + detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}", ) return current_user @@ -106,18 +102,18 @@ class OrganizationPermission: # Common permission presets for convenience require_org_owner = OrganizationPermission([OrganizationRole.OWNER]) -require_org_admin = OrganizationPermission([OrganizationRole.OWNER, OrganizationRole.ADMIN]) -require_org_member = OrganizationPermission([ - OrganizationRole.OWNER, - OrganizationRole.ADMIN, - OrganizationRole.MEMBER -]) +require_org_admin = OrganizationPermission( + [OrganizationRole.OWNER, OrganizationRole.ADMIN] +) +require_org_member = OrganizationPermission( + [OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MEMBER] +) async def require_org_membership( organization_id: UUID, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> User: """ Ensure user is a member of the organization (any role). @@ -128,15 +124,13 @@ async def require_org_membership( return current_user user_role = await organization_crud.get_user_role_in_org( - db, - user_id=current_user.id, - organization_id=organization_id + db, user_id=current_user.id, organization_id=organization_id ) if not user_role: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Not a member of this organization" + detail="Not a member of this organization", ) return current_user diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 44949e3..135e8c8 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,10 +1,12 @@ from fastapi import APIRouter -from app.api.routes import auth, users, sessions, admin, organizations +from app.api.routes import admin, auth, organizations, sessions, users api_router = APIRouter() api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"]) api_router.include_router(users.router, prefix="/users", tags=["Users"]) api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"]) api_router.include_router(admin.router, prefix="/admin", tags=["Admin"]) -api_router.include_router(organizations.router, prefix="/organizations", tags=["Organizations"]) +api_router.include_router( + organizations.router, prefix="/organizations", tags=["Organizations"] +) diff --git a/backend/app/api/routes/admin.py b/backend/app/api/routes/admin.py index b4dde7e..5ff0767 100755 --- a/backend/app/api/routes/admin.py +++ b/backend/app/api/routes/admin.py @@ -5,9 +5,10 @@ Admin-specific endpoints for managing users and organizations. These endpoints require superuser privileges and provide CMS-like functionality for managing the application. """ + import logging from enum import Enum -from typing import Any, List, Optional +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends, Query, status @@ -16,27 +17,32 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.permissions import require_superuser from app.core.database import get_db -from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode +from app.core.exceptions import ( + AuthorizationError, + DuplicateError, + ErrorCode, + NotFoundError, +) from app.crud.organization import organization as organization_crud -from app.crud.user import user as user_crud from app.crud.session import session as session_crud +from app.crud.user import user as user_crud from app.models.user import User from app.models.user_organization import OrganizationRole from app.schemas.common import ( - PaginationParams, - PaginatedResponse, MessageResponse, + PaginatedResponse, + PaginationParams, SortParams, - create_pagination_meta + create_pagination_meta, ) from app.schemas.organizations import ( - OrganizationResponse, OrganizationCreate, + OrganizationMemberResponse, + OrganizationResponse, OrganizationUpdate, - OrganizationMemberResponse ) -from app.schemas.users import UserResponse, UserCreate, UserUpdate from app.schemas.sessions import AdminSessionResponse +from app.schemas.users import UserCreate, UserResponse, UserUpdate logger = logging.getLogger(__name__) @@ -46,6 +52,7 @@ router = APIRouter() # Schemas for bulk operations class BulkAction(str, Enum): """Supported bulk actions.""" + ACTIVATE = "activate" DEACTIVATE = "deactivate" DELETE = "delete" @@ -53,36 +60,41 @@ class BulkAction(str, Enum): class BulkUserAction(BaseModel): """Schema for bulk user actions.""" + action: BulkAction = Field(..., description="Action to perform on selected users") - user_ids: List[UUID] = Field(..., min_items=1, max_items=100, description="List of user IDs (max 100)") + user_ids: list[UUID] = Field( + ..., min_items=1, max_items=100, description="List of user IDs (max 100)" + ) class BulkActionResult(BaseModel): """Result of a bulk action.""" + success: bool affected_count: int failed_count: int message: str - failed_ids: Optional[List[UUID]] = [] + failed_ids: list[UUID] | None = [] # ===== User Management Endpoints ===== + @router.get( "/users", response_model=PaginatedResponse[UserResponse], summary="Admin: List All Users", description="Get paginated list of all users with filtering and search (admin only)", - operation_id="admin_list_users" + operation_id="admin_list_users", ) async def admin_list_users( pagination: PaginationParams = Depends(), sort: SortParams = Depends(), - is_active: Optional[bool] = Query(None, description="Filter by active status"), - is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), - search: Optional[str] = Query(None, description="Search by email, name"), + is_active: bool | None = Query(None, description="Filter by active status"), + is_superuser: bool | None = Query(None, description="Filter by superuser status"), + search: str | None = Query(None, description="Search by email, name"), admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ List all users with comprehensive filtering and search. @@ -105,20 +117,20 @@ async def admin_list_users( sort_by=sort.sort_by or "created_at", sort_order=sort.sort_order.value if sort.sort_order else "desc", filters=filters if filters else None, - search=search + search=search, ) pagination_meta = create_pagination_meta( total=total, page=pagination.page, limit=pagination.limit, - items_count=len(users) + items_count=len(users), ) return PaginatedResponse(data=users, pagination=pagination_meta) except Exception as e: - logger.error(f"Error listing users (admin): {str(e)}", exc_info=True) + logger.error(f"Error listing users (admin): {e!s}", exc_info=True) raise @@ -128,12 +140,12 @@ async def admin_list_users( status_code=status.HTTP_201_CREATED, summary="Admin: Create User", description="Create a new user (admin only)", - operation_id="admin_create_user" + operation_id="admin_create_user", ) async def admin_create_user( user_in: UserCreate, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Create a new user with admin privileges. @@ -145,13 +157,10 @@ async def admin_create_user( logger.info(f"Admin {admin.email} created user {user.email}") return user except ValueError as e: - logger.warning(f"Failed to create user: {str(e)}") - raise NotFoundError( - message=str(e), - error_code=ErrorCode.USER_ALREADY_EXISTS - ) + logger.warning(f"Failed to create user: {e!s}") + raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS) except Exception as e: - logger.error(f"Error creating user (admin): {str(e)}", exc_info=True) + logger.error(f"Error creating user (admin): {e!s}", exc_info=True) raise @@ -160,19 +169,18 @@ async def admin_create_user( response_model=UserResponse, summary="Admin: Get User Details", description="Get detailed user information (admin only)", - operation_id="admin_get_user" + operation_id="admin_get_user", ) async def admin_get_user( user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Get detailed information about a specific user.""" user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - message=f"User {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) return user @@ -182,21 +190,20 @@ async def admin_get_user( response_model=UserResponse, summary="Admin: Update User", description="Update user information (admin only)", - operation_id="admin_update_user" + operation_id="admin_update_user", ) async def admin_update_user( user_id: UUID, user_in: UserUpdate, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Update user information with admin privileges.""" try: user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - message=f"User {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in) @@ -206,7 +213,7 @@ async def admin_update_user( except NotFoundError: raise except Exception as e: - logger.error(f"Error updating user (admin): {str(e)}", exc_info=True) + logger.error(f"Error updating user (admin): {e!s}", exc_info=True) raise @@ -215,20 +222,19 @@ async def admin_update_user( response_model=MessageResponse, summary="Admin: Delete User", description="Soft delete a user (admin only)", - operation_id="admin_delete_user" + operation_id="admin_delete_user", ) async def admin_delete_user( user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Soft delete a user (sets deleted_at timestamp).""" try: user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - message=f"User {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) # Prevent deleting yourself @@ -236,21 +242,20 @@ async def admin_delete_user( # Use AuthorizationError for permission/operation restrictions raise AuthorizationError( message="Cannot delete your own account", - error_code=ErrorCode.OPERATION_FORBIDDEN + error_code=ErrorCode.OPERATION_FORBIDDEN, ) await user_crud.soft_delete(db, id=user_id) logger.info(f"Admin {admin.email} deleted user {user.email}") return MessageResponse( - success=True, - message=f"User {user.email} has been deleted" + success=True, message=f"User {user.email} has been deleted" ) except NotFoundError: raise except Exception as e: - logger.error(f"Error deleting user (admin): {str(e)}", exc_info=True) + logger.error(f"Error deleting user (admin): {e!s}", exc_info=True) raise @@ -259,34 +264,32 @@ async def admin_delete_user( response_model=MessageResponse, summary="Admin: Activate User", description="Activate a user account (admin only)", - operation_id="admin_activate_user" + operation_id="admin_activate_user", ) async def admin_activate_user( user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Activate a user account.""" try: user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - message=f"User {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) await user_crud.update(db, db_obj=user, obj_in={"is_active": True}) logger.info(f"Admin {admin.email} activated user {user.email}") return MessageResponse( - success=True, - message=f"User {user.email} has been activated" + success=True, message=f"User {user.email} has been activated" ) except NotFoundError: raise except Exception as e: - logger.error(f"Error activating user (admin): {str(e)}", exc_info=True) + logger.error(f"Error activating user (admin): {e!s}", exc_info=True) raise @@ -295,20 +298,19 @@ async def admin_activate_user( response_model=MessageResponse, summary="Admin: Deactivate User", description="Deactivate a user account (admin only)", - operation_id="admin_deactivate_user" + operation_id="admin_deactivate_user", ) async def admin_deactivate_user( user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Deactivate a user account.""" try: user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - message=f"User {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) # Prevent deactivating yourself @@ -316,21 +318,20 @@ async def admin_deactivate_user( # Use AuthorizationError for permission/operation restrictions raise AuthorizationError( message="Cannot deactivate your own account", - error_code=ErrorCode.OPERATION_FORBIDDEN + error_code=ErrorCode.OPERATION_FORBIDDEN, ) await user_crud.update(db, db_obj=user, obj_in={"is_active": False}) logger.info(f"Admin {admin.email} deactivated user {user.email}") return MessageResponse( - success=True, - message=f"User {user.email} has been deactivated" + success=True, message=f"User {user.email} has been deactivated" ) except NotFoundError: raise except Exception as e: - logger.error(f"Error deactivating user (admin): {str(e)}", exc_info=True) + logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True) raise @@ -339,12 +340,12 @@ async def admin_deactivate_user( response_model=BulkActionResult, summary="Admin: Bulk User Action", description="Perform bulk actions on multiple users (admin only)", - operation_id="admin_bulk_user_action" + operation_id="admin_bulk_user_action", ) async def admin_bulk_user_action( bulk_action: BulkUserAction, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Perform bulk actions on multiple users using optimized bulk operations. @@ -356,22 +357,16 @@ async def admin_bulk_user_action( # Use efficient bulk operations instead of loop if bulk_action.action == BulkAction.ACTIVATE: affected_count = await user_crud.bulk_update_status( - db, - user_ids=bulk_action.user_ids, - is_active=True + db, user_ids=bulk_action.user_ids, is_active=True ) elif bulk_action.action == BulkAction.DEACTIVATE: affected_count = await user_crud.bulk_update_status( - db, - user_ids=bulk_action.user_ids, - is_active=False + db, user_ids=bulk_action.user_ids, is_active=False ) elif bulk_action.action == BulkAction.DELETE: # bulk_soft_delete automatically excludes the admin user affected_count = await user_crud.bulk_soft_delete( - db, - user_ids=bulk_action.user_ids, - exclude_user_id=admin.id + db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id ) else: raise ValueError(f"Unsupported bulk action: {bulk_action.action}") @@ -390,29 +385,30 @@ async def admin_bulk_user_action( affected_count=affected_count, failed_count=failed_count, message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped", - failed_ids=None # Bulk operations don't track individual failures + failed_ids=None, # Bulk operations don't track individual failures ) except Exception as e: - logger.error(f"Error in bulk user action: {str(e)}", exc_info=True) + logger.error(f"Error in bulk user action: {e!s}", exc_info=True) raise # ===== Organization Management Endpoints ===== + @router.get( "/organizations", response_model=PaginatedResponse[OrganizationResponse], summary="Admin: List Organizations", description="Get paginated list of all organizations (admin only)", - operation_id="admin_list_organizations" + operation_id="admin_list_organizations", ) async def admin_list_organizations( pagination: PaginationParams = Depends(), - is_active: Optional[bool] = Query(None, description="Filter by active status"), - search: Optional[str] = Query(None, description="Search by name, slug, description"), + is_active: bool | None = Query(None, description="Filter by active status"), + search: str | None = Query(None, description="Search by name, slug, description"), admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """List all organizations with filtering and search.""" try: @@ -422,14 +418,14 @@ async def admin_list_organizations( skip=pagination.offset, limit=pagination.limit, is_active=is_active, - search=search + search=search, ) # Build response objects from optimized query results orgs_with_count = [] for item in orgs_with_data: - org = item['organization'] - member_count = item['member_count'] + org = item["organization"] + member_count = item["member_count"] org_dict = { "id": org.id, @@ -440,7 +436,7 @@ async def admin_list_organizations( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": member_count + "member_count": member_count, } orgs_with_count.append(OrganizationResponse(**org_dict)) @@ -448,13 +444,13 @@ async def admin_list_organizations( total=total, page=pagination.page, limit=pagination.limit, - items_count=len(orgs_with_count) + items_count=len(orgs_with_count), ) return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta) except Exception as e: - logger.error(f"Error listing organizations (admin): {str(e)}", exc_info=True) + logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True) raise @@ -464,12 +460,12 @@ async def admin_list_organizations( status_code=status.HTTP_201_CREATED, summary="Admin: Create Organization", description="Create a new organization (admin only)", - operation_id="admin_create_organization" + operation_id="admin_create_organization", ) async def admin_create_organization( org_in: OrganizationCreate, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Create a new organization.""" try: @@ -486,18 +482,15 @@ async def admin_create_organization( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": 0 + "member_count": 0, } return OrganizationResponse(**org_dict) except ValueError as e: - logger.warning(f"Failed to create organization: {str(e)}") - raise NotFoundError( - message=str(e), - error_code=ErrorCode.ALREADY_EXISTS - ) + logger.warning(f"Failed to create organization: {e!s}") + raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS) except Exception as e: - logger.error(f"Error creating organization (admin): {str(e)}", exc_info=True) + logger.error(f"Error creating organization (admin): {e!s}", exc_info=True) raise @@ -506,19 +499,18 @@ async def admin_create_organization( response_model=OrganizationResponse, summary="Admin: Get Organization Details", description="Get detailed organization information (admin only)", - operation_id="admin_get_organization" + operation_id="admin_get_organization", ) async def admin_get_organization( org_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Get detailed information about a specific organization.""" org = await organization_crud.get(db, id=org_id) if not org: raise NotFoundError( - message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND + message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND ) org_dict = { @@ -530,7 +522,9 @@ async def admin_get_organization( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": await organization_crud.get_member_count(db, organization_id=org.id) + "member_count": await organization_crud.get_member_count( + db, organization_id=org.id + ), } return OrganizationResponse(**org_dict) @@ -540,13 +534,13 @@ async def admin_get_organization( response_model=OrganizationResponse, summary="Admin: Update Organization", description="Update organization information (admin only)", - operation_id="admin_update_organization" + operation_id="admin_update_organization", ) async def admin_update_organization( org_id: UUID, org_in: OrganizationUpdate, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Update organization information.""" try: @@ -554,7 +548,7 @@ async def admin_update_organization( if not org: raise NotFoundError( message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND + error_code=ErrorCode.NOT_FOUND, ) updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in) @@ -569,14 +563,16 @@ async def admin_update_organization( "settings": updated_org.settings, "created_at": updated_org.created_at, "updated_at": updated_org.updated_at, - "member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id) + "member_count": await organization_crud.get_member_count( + db, organization_id=updated_org.id + ), } return OrganizationResponse(**org_dict) except NotFoundError: raise except Exception as e: - logger.error(f"Error updating organization (admin): {str(e)}", exc_info=True) + logger.error(f"Error updating organization (admin): {e!s}", exc_info=True) raise @@ -585,12 +581,12 @@ async def admin_update_organization( response_model=MessageResponse, summary="Admin: Delete Organization", description="Delete an organization (admin only)", - operation_id="admin_delete_organization" + operation_id="admin_delete_organization", ) async def admin_delete_organization( org_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Delete an organization and all its relationships.""" try: @@ -598,21 +594,20 @@ async def admin_delete_organization( if not org: raise NotFoundError( message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND + error_code=ErrorCode.NOT_FOUND, ) await organization_crud.remove(db, id=org_id) logger.info(f"Admin {admin.email} deleted organization {org.name}") return MessageResponse( - success=True, - message=f"Organization {org.name} has been deleted" + success=True, message=f"Organization {org.name} has been deleted" ) except NotFoundError: raise except Exception as e: - logger.error(f"Error deleting organization (admin): {str(e)}", exc_info=True) + logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True) raise @@ -621,14 +616,14 @@ async def admin_delete_organization( response_model=PaginatedResponse[OrganizationMemberResponse], summary="Admin: List Organization Members", description="Get all members of an organization (admin only)", - operation_id="admin_list_organization_members" + operation_id="admin_list_organization_members", ) async def admin_list_organization_members( org_id: UUID, pagination: PaginationParams = Depends(), - is_active: Optional[bool] = Query(True, description="Filter by active status"), + is_active: bool | None = Query(True, description="Filter by active status"), admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """List all members of an organization.""" try: @@ -636,7 +631,7 @@ async def admin_list_organization_members( if not org: raise NotFoundError( message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND + error_code=ErrorCode.NOT_FOUND, ) members, total = await organization_crud.get_organization_members( @@ -644,7 +639,7 @@ async def admin_list_organization_members( organization_id=org_id, skip=pagination.offset, limit=pagination.limit, - is_active=is_active + is_active=is_active, ) # Convert to response models @@ -654,7 +649,7 @@ async def admin_list_organization_members( total=total, page=pagination.page, limit=pagination.limit, - items_count=len(member_responses) + items_count=len(member_responses), ) return PaginatedResponse(data=member_responses, pagination=pagination_meta) @@ -662,14 +657,19 @@ async def admin_list_organization_members( except NotFoundError: raise except Exception as e: - logger.error(f"Error listing organization members (admin): {str(e)}", exc_info=True) + logger.error( + f"Error listing organization members (admin): {e!s}", exc_info=True + ) raise class AddMemberRequest(BaseModel): """Request to add a member to an organization.""" + user_id: UUID = Field(..., description="User ID to add") - role: OrganizationRole = Field(OrganizationRole.MEMBER, description="Role in organization") + role: OrganizationRole = Field( + OrganizationRole.MEMBER, description="Role in organization" + ) @router.post( @@ -677,13 +677,13 @@ class AddMemberRequest(BaseModel): response_model=MessageResponse, summary="Admin: Add Member to Organization", description="Add a user to an organization (admin only)", - operation_id="admin_add_organization_member" + operation_id="admin_add_organization_member", ) async def admin_add_organization_member( org_id: UUID, request: AddMemberRequest, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Add a user to an organization.""" try: @@ -691,21 +691,18 @@ async def admin_add_organization_member( if not org: raise NotFoundError( message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND + error_code=ErrorCode.NOT_FOUND, ) user = await user_crud.get(db, id=request.user_id) if not user: raise NotFoundError( message=f"User {request.user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + error_code=ErrorCode.USER_NOT_FOUND, ) await organization_crud.add_user( - db, - organization_id=org_id, - user_id=request.user_id, - role=request.role + db, organization_id=org_id, user_id=request.user_id, role=request.role ) logger.info( @@ -714,22 +711,21 @@ async def admin_add_organization_member( ) return MessageResponse( - success=True, - message=f"User {user.email} added to organization {org.name}" + success=True, message=f"User {user.email} added to organization {org.name}" ) except ValueError as e: - logger.warning(f"Failed to add user to organization: {str(e)}") + logger.warning(f"Failed to add user to organization: {e!s}") # Use DuplicateError for "already exists" scenarios raise DuplicateError( - message=str(e), - error_code=ErrorCode.USER_ALREADY_EXISTS, - field="user_id" + message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id" ) except NotFoundError: raise except Exception as e: - logger.error(f"Error adding member to organization (admin): {str(e)}", exc_info=True) + logger.error( + f"Error adding member to organization (admin): {e!s}", exc_info=True + ) raise @@ -738,13 +734,13 @@ async def admin_add_organization_member( response_model=MessageResponse, summary="Admin: Remove Member from Organization", description="Remove a user from an organization (admin only)", - operation_id="admin_remove_organization_member" + operation_id="admin_remove_organization_member", ) async def admin_remove_organization_member( org_id: UUID, user_id: UUID, admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """Remove a user from an organization.""" try: @@ -752,39 +748,40 @@ async def admin_remove_organization_member( if not org: raise NotFoundError( message=f"Organization {org_id} not found", - error_code=ErrorCode.NOT_FOUND + error_code=ErrorCode.NOT_FOUND, ) user = await user_crud.get(db, id=user_id) if not user: raise NotFoundError( - message=f"User {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND ) success = await organization_crud.remove_user( - db, - organization_id=org_id, - user_id=user_id + db, organization_id=org_id, user_id=user_id ) if not success: raise NotFoundError( message="User is not a member of this organization", - error_code=ErrorCode.NOT_FOUND + error_code=ErrorCode.NOT_FOUND, ) - logger.info(f"Admin {admin.email} removed user {user.email} from organization {org.name}") + logger.info( + f"Admin {admin.email} removed user {user.email} from organization {org.name}" + ) return MessageResponse( success=True, - message=f"User {user.email} removed from organization {org.name}" + message=f"User {user.email} removed from organization {org.name}", ) except NotFoundError: raise except Exception as e: - logger.error(f"Error removing member from organization (admin): {str(e)}", exc_info=True) + logger.error( + f"Error removing member from organization (admin): {e!s}", exc_info=True + ) raise @@ -792,6 +789,7 @@ async def admin_remove_organization_member( # Session Management Endpoints # ============================================================================ + @router.get( "/sessions", response_model=PaginatedResponse[AdminSessionResponse], @@ -802,13 +800,13 @@ async def admin_remove_organization_member( Returns paginated list of sessions with user information. Useful for admin dashboard statistics and session monitoring. """, - operation_id="admin_list_sessions" + operation_id="admin_list_sessions", ) async def admin_list_sessions( pagination: PaginationParams = Depends(), - is_active: Optional[bool] = Query(None, description="Filter by active status"), + is_active: bool | None = Query(None, description="Filter by active status"), admin: User = Depends(require_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """List all sessions across all users with filtering and pagination.""" try: @@ -818,7 +816,7 @@ async def admin_list_sessions( skip=pagination.offset, limit=pagination.limit, active_only=is_active if is_active is not None else True, - with_user=True + with_user=True, ) # Build response objects with user information @@ -847,21 +845,23 @@ async def admin_list_sessions( last_used_at=session.last_used_at, created_at=session.created_at, expires_at=session.expires_at, - is_active=session.is_active + is_active=session.is_active, ) session_responses.append(session_response) - logger.info(f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})") + logger.info( + f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})" + ) pagination_meta = create_pagination_meta( total=total, page=pagination.page, limit=pagination.limit, - items_count=len(session_responses) + items_count=len(session_responses), ) return PaginatedResponse(data=session_responses, pagination=pagination_meta) except Exception as e: - logger.error(f"Error listing sessions (admin): {str(e)}", exc_info=True) + logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True) raise diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index c576204..72ba3cb 100755 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -1,39 +1,43 @@ # app/api/routes/auth.py import logging import os -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any -from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.security import OAuth2PasswordRequestForm from slowapi import Limiter from slowapi.util import get_remote_address from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user -from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token -from app.core.auth import get_password_hash +from app.core.auth import ( + TokenExpiredError, + TokenInvalidError, + decode_token, + get_password_hash, +) from app.core.database import get_db from app.core.exceptions import ( AuthenticationError as AuthError, DatabaseError, - ErrorCode + ErrorCode, ) from app.crud.session import session as session_crud from app.crud.user import user as user_crud from app.models.user import User from app.schemas.common import MessageResponse -from app.schemas.sessions import SessionCreate, LogoutRequest +from app.schemas.sessions import LogoutRequest, SessionCreate from app.schemas.users import ( + LoginRequest, + PasswordResetConfirm, + PasswordResetRequest, + RefreshTokenRequest, + Token, UserCreate, UserResponse, - Token, - LoginRequest, - RefreshTokenRequest, - PasswordResetRequest, - PasswordResetConfirm ) -from app.services.auth_service import AuthService, AuthenticationError +from app.services.auth_service import AuthenticationError, AuthService from app.services.email_service import email_service from app.utils.device import extract_device_info from app.utils.security import create_password_reset_token, verify_password_reset_token @@ -54,7 +58,7 @@ async def _create_login_session( request: Request, user: User, tokens: Token, - login_type: str = "login" + login_type: str = "login", ) -> None: """ Create a session record for successful login. @@ -81,8 +85,8 @@ async def _create_login_session( device_id=device_info.device_id, ip_address=device_info.ip_address, user_agent=device_info.user_agent, - last_used_at=datetime.now(timezone.utc), - expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc), + last_used_at=datetime.now(UTC), + expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC), location_city=device_info.location_city, location_country=device_info.location_country, ) @@ -95,15 +99,20 @@ async def _create_login_session( ) except Exception as session_err: # Log but don't fail login if session creation fails - logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True) + logger.error( + f"Failed to create session for {user.email}: {session_err!s}", exc_info=True + ) -@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register") +@router.post( + "/register", + response_model=UserResponse, + status_code=status.HTTP_201_CREATED, + operation_id="register", +) @limiter.limit(f"{5 * RATE_MULTIPLIER}/minute") async def register_user( - request: Request, - user_data: UserCreate, - db: AsyncSession = Depends(get_db) + request: Request, user_data: UserCreate, db: AsyncSession = Depends(get_db) ) -> Any: """ Register a new user. @@ -116,25 +125,23 @@ async def register_user( return user except AuthenticationError as e: # SECURITY: Don't reveal if email exists - generic error message - logger.warning(f"Registration failed: {str(e)}") + logger.warning(f"Registration failed: {e!s}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Registration failed. Please check your information and try again." + detail="Registration failed. Please check your information and try again.", ) except Exception as e: - logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True) + logger.error(f"Unexpected error during registration: {e!s}", exc_info=True) raise DatabaseError( message="An unexpected error occurred. Please try again later.", - error_code=ErrorCode.INTERNAL_ERROR + error_code=ErrorCode.INTERNAL_ERROR, ) @router.post("/login", response_model=Token, operation_id="login") @limiter.limit(f"{10 * RATE_MULTIPLIER}/minute") async def login( - request: Request, - login_data: LoginRequest, - db: AsyncSession = Depends(get_db) + request: Request, login_data: LoginRequest, db: AsyncSession = Depends(get_db) ) -> Any: """ Login with username and password. @@ -146,14 +153,16 @@ async def login( """ try: # Attempt to authenticate the user - user = await AuthService.authenticate_user(db, login_data.email, login_data.password) + user = await AuthService.authenticate_user( + db, login_data.email, login_data.password + ) # Explicitly check for None result and raise correct exception if user is None: logger.warning(f"Invalid login attempt for: {login_data.email}") raise AuthError( message="Invalid email or password", - error_code=ErrorCode.INVALID_CREDENTIALS + error_code=ErrorCode.INVALID_CREDENTIALS, ) # User is authenticated, generate tokens @@ -166,29 +175,26 @@ async def login( except AuthenticationError as e: # Handle specific authentication errors like inactive accounts - logger.warning(f"Authentication failed: {str(e)}") - raise AuthError( - message=str(e), - error_code=ErrorCode.INVALID_CREDENTIALS - ) + logger.warning(f"Authentication failed: {e!s}") + raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS) except AuthError: # Re-raise custom auth exceptions without modification raise except Exception as e: # Handle unexpected errors - logger.error(f"Unexpected error during login: {str(e)}", exc_info=True) + logger.error(f"Unexpected error during login: {e!s}", exc_info=True) raise DatabaseError( message="An unexpected error occurred. Please try again later.", - error_code=ErrorCode.INTERNAL_ERROR + error_code=ErrorCode.INTERNAL_ERROR, ) -@router.post("/login/oauth", response_model=Token, operation_id='login_oauth') +@router.post("/login/oauth", response_model=Token, operation_id="login_oauth") @limiter.limit("10/minute") async def login_oauth( - request: Request, - form_data: OAuth2PasswordRequestForm = Depends(), - db: AsyncSession = Depends(get_db) + request: Request, + form_data: OAuth2PasswordRequestForm = Depends(), + db: AsyncSession = Depends(get_db), ) -> Any: """ OAuth2-compatible login endpoint, used by the OpenAPI UI. @@ -199,12 +205,14 @@ async def login_oauth( Access and refresh tokens. """ try: - user = await AuthService.authenticate_user(db, form_data.username, form_data.password) + user = await AuthService.authenticate_user( + db, form_data.username, form_data.password + ) if user is None: raise AuthError( message="Invalid email or password", - error_code=ErrorCode.INVALID_CREDENTIALS + error_code=ErrorCode.INVALID_CREDENTIALS, ) # Generate tokens @@ -216,28 +224,25 @@ async def login_oauth( # Return full token response with user data return tokens except AuthenticationError as e: - logger.warning(f"OAuth authentication failed: {str(e)}") - raise AuthError( - message=str(e), - error_code=ErrorCode.INVALID_CREDENTIALS - ) + logger.warning(f"OAuth authentication failed: {e!s}") + raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS) except AuthError: # Re-raise custom auth exceptions without modification raise except Exception as e: - logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True) + logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True) raise DatabaseError( message="An unexpected error occurred. Please try again later.", - error_code=ErrorCode.INTERNAL_ERROR + error_code=ErrorCode.INTERNAL_ERROR, ) @router.post("/refresh", response_model=Token, operation_id="refresh_token") @limiter.limit("30/minute") async def refresh_token( - request: Request, - refresh_data: RefreshTokenRequest, - db: AsyncSession = Depends(get_db) + request: Request, + refresh_data: RefreshTokenRequest, + db: AsyncSession = Depends(get_db), ) -> Any: """ Refresh access token using a refresh token. @@ -249,13 +254,17 @@ async def refresh_token( """ try: # Decode the refresh token to get the JTI - refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh") + refresh_payload = decode_token( + refresh_data.refresh_token, verify_type="refresh" + ) # Check if session exists and is active session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti) if not session: - logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}") + logger.warning( + f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}" + ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Session has been revoked. Please log in again.", @@ -274,10 +283,12 @@ async def refresh_token( db, session=session, new_jti=new_refresh_payload.jti, - new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc) + new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC), ) except Exception as session_err: - logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True) + logger.error( + f"Failed to update session {session.id}: {session_err!s}", exc_info=True + ) # Continue anyway - tokens are already issued return tokens @@ -300,10 +311,10 @@ async def refresh_token( # Re-raise HTTP exceptions (like session revoked) raise except Exception as e: - logger.error(f"Unexpected error during token refresh: {str(e)}") + logger.error(f"Unexpected error during token refresh: {e!s}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="An unexpected error occurred. Please try again later." + detail="An unexpected error occurred. Please try again later.", ) @@ -320,13 +331,13 @@ async def refresh_token( **Rate Limit**: 3 requests/minute """, - operation_id="request_password_reset" + operation_id="request_password_reset", ) @limiter.limit("3/minute") async def request_password_reset( request: Request, reset_request: PasswordResetRequest, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Request a password reset. @@ -345,26 +356,26 @@ async def request_password_reset( # Send password reset email await email_service.send_password_reset_email( - to_email=user.email, - reset_token=reset_token, - user_name=user.first_name + to_email=user.email, reset_token=reset_token, user_name=user.first_name ) logger.info(f"Password reset requested for {user.email}") else: # Log attempt but don't reveal if email exists - logger.warning(f"Password reset requested for non-existent or inactive email: {reset_request.email}") + logger.warning( + f"Password reset requested for non-existent or inactive email: {reset_request.email}" + ) # Always return success to prevent email enumeration return MessageResponse( success=True, - message="If your email is registered, you will receive a password reset link shortly" + message="If your email is registered, you will receive a password reset link shortly", ) except Exception as e: - logger.error(f"Error processing password reset request: {str(e)}", exc_info=True) + logger.error(f"Error processing password reset request: {e!s}", exc_info=True) # Still return success to prevent information leakage return MessageResponse( success=True, - message="If your email is registered, you will receive a password reset link shortly" + message="If your email is registered, you will receive a password reset link shortly", ) @@ -378,13 +389,13 @@ async def request_password_reset( **Rate Limit**: 5 requests/minute """, - operation_id="confirm_password_reset" + operation_id="confirm_password_reset", ) @limiter.limit("5/minute") async def confirm_password_reset( request: Request, reset_confirm: PasswordResetConfirm, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Confirm password reset with token. @@ -398,7 +409,7 @@ async def confirm_password_reset( if not email: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid or expired password reset token" + detail="Invalid or expired password reset token", ) # Look up user @@ -406,14 +417,13 @@ async def confirm_password_reset( if not user: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) if not user.is_active: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="User account is inactive" + detail="User account is inactive", ) # Update password @@ -424,29 +434,33 @@ async def confirm_password_reset( # SECURITY: Invalidate all existing sessions after password reset # This prevents stolen sessions from being used after password change from app.crud.session import session as session_crud + try: deactivated_count = await session_crud.deactivate_all_user_sessions( - db, - user_id=str(user.id) + db, user_id=str(user.id) + ) + logger.info( + f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions" ) - logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions") except Exception as session_error: # Log but don't fail password reset if session invalidation fails - logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}") + logger.error( + f"Failed to invalidate sessions after password reset: {session_error!s}" + ) return MessageResponse( success=True, - message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password." + message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password.", ) except HTTPException: raise except Exception as e: - logger.error(f"Error confirming password reset: {str(e)}", exc_info=True) + logger.error(f"Error confirming password reset: {e!s}", exc_info=True) await db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="An error occurred while resetting your password" + detail="An error occurred while resetting your password", ) @@ -464,14 +478,14 @@ async def confirm_password_reset( **Rate Limit**: 10 requests/minute """, - operation_id="logout" + operation_id="logout", ) @limiter.limit("10/minute") async def logout( request: Request, logout_request: LogoutRequest, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Logout from current device by deactivating the session. @@ -487,15 +501,14 @@ async def logout( try: # Decode refresh token to get JTI try: - refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh") + refresh_payload = decode_token( + logout_request.refresh_token, verify_type="refresh" + ) except (TokenExpiredError, TokenInvalidError) as e: # Even if token is expired/invalid, try to deactivate session - logger.warning(f"Logout with invalid/expired token: {str(e)}") + logger.warning(f"Logout with invalid/expired token: {e!s}") # Don't fail - return success anyway - return MessageResponse( - success=True, - message="Logged out successfully" - ) + return MessageResponse(success=True, message="Logged out successfully") # Find the session by JTI session = await session_crud.get_by_jti(db, jti=refresh_payload.jti) @@ -509,7 +522,7 @@ async def logout( ) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="You can only logout your own sessions" + detail="You can only logout your own sessions", ) # Deactivate the session @@ -522,22 +535,20 @@ async def logout( else: # Session not found - maybe already deleted or never existed # Return success anyway (idempotent) - logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})") + logger.info( + f"Logout requested for non-existent session (JTI: {refresh_payload.jti})" + ) - return MessageResponse( - success=True, - message="Logged out successfully" - ) + return MessageResponse(success=True, message="Logged out successfully") except HTTPException: raise except Exception as e: - logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True) - # Don't expose error details - return MessageResponse( - success=True, - message="Logged out successfully" + logger.error( + f"Error during logout for user {current_user.id}: {e!s}", exc_info=True ) + # Don't expose error details + return MessageResponse(success=True, message="Logged out successfully") @router.post( @@ -553,13 +564,13 @@ async def logout( **Rate Limit**: 5 requests/minute """, - operation_id="logout_all" + operation_id="logout_all", ) @limiter.limit("5/minute") async def logout_all( request: Request, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Logout from all devices by deactivating all user sessions. @@ -573,19 +584,25 @@ async def logout_all( """ try: # Deactivate all sessions for this user - count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id)) + count = await session_crud.deactivate_all_user_sessions( + db, user_id=str(current_user.id) + ) - logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)") + logger.info( + f"User {current_user.id} logged out from all devices ({count} sessions)" + ) return MessageResponse( success=True, - message=f"Successfully logged out from all devices ({count} sessions terminated)" + message=f"Successfully logged out from all devices ({count} sessions terminated)", ) except Exception as e: - logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True) + logger.error( + f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True + ) await db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="An error occurred while logging out" + detail="An error occurred while logging out", ) diff --git a/backend/app/api/routes/organizations.py b/backend/app/api/routes/organizations.py index 4a559f3..6d15c0d 100755 --- a/backend/app/api/routes/organizations.py +++ b/backend/app/api/routes/organizations.py @@ -4,8 +4,9 @@ Organization endpoints for regular users. These endpoints allow users to view and manage organizations they belong to. """ + import logging -from typing import Any, List +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends, Query @@ -14,18 +15,18 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.api.dependencies.permissions import require_org_admin, require_org_membership from app.core.database import get_db -from app.core.exceptions import NotFoundError, ErrorCode +from app.core.exceptions import ErrorCode, NotFoundError from app.crud.organization import organization as organization_crud from app.models.user import User from app.schemas.common import ( - PaginationParams, PaginatedResponse, - create_pagination_meta + PaginationParams, + create_pagination_meta, ) from app.schemas.organizations import ( - OrganizationResponse, OrganizationMemberResponse, - OrganizationUpdate + OrganizationResponse, + OrganizationUpdate, ) logger = logging.getLogger(__name__) @@ -35,15 +36,15 @@ router = APIRouter() @router.get( "/me", - response_model=List[OrganizationResponse], + response_model=list[OrganizationResponse], summary="Get My Organizations", description="Get all organizations the current user belongs to", - operation_id="get_my_organizations" + operation_id="get_my_organizations", ) async def get_my_organizations( is_active: bool = Query(True, description="Filter by active membership"), current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Get all organizations the current user belongs to. @@ -54,15 +55,13 @@ async def get_my_organizations( try: # Get all org data in single query with JOIN and subquery orgs_data = await organization_crud.get_user_organizations_with_details( - db, - user_id=current_user.id, - is_active=is_active + db, user_id=current_user.id, is_active=is_active ) # Transform to response objects orgs_with_data = [] for item in orgs_data: - org = item['organization'] + org = item["organization"] org_dict = { "id": org.id, "name": org.name, @@ -72,14 +71,14 @@ async def get_my_organizations( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": item['member_count'] + "member_count": item["member_count"], } orgs_with_data.append(OrganizationResponse(**org_dict)) return orgs_with_data except Exception as e: - logger.error(f"Error getting user organizations: {str(e)}", exc_info=True) + logger.error(f"Error getting user organizations: {e!s}", exc_info=True) raise @@ -88,12 +87,12 @@ async def get_my_organizations( response_model=OrganizationResponse, summary="Get Organization Details", description="Get details of an organization the user belongs to", - operation_id="get_organization" + operation_id="get_organization", ) async def get_organization( organization_id: UUID, current_user: User = Depends(require_org_membership), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Get details of a specific organization. @@ -105,7 +104,7 @@ async def get_organization( if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md) raise NotFoundError( detail=f"Organization {organization_id} not found", - error_code=ErrorCode.NOT_FOUND + error_code=ErrorCode.NOT_FOUND, ) org_dict = { @@ -117,14 +116,16 @@ async def get_organization( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": await organization_crud.get_member_count(db, organization_id=org.id) + "member_count": await organization_crud.get_member_count( + db, organization_id=org.id + ), } return OrganizationResponse(**org_dict) except NotFoundError: # pragma: no cover - See above raise except Exception as e: - logger.error(f"Error getting organization: {str(e)}", exc_info=True) + logger.error(f"Error getting organization: {e!s}", exc_info=True) raise @@ -133,14 +134,14 @@ async def get_organization( response_model=PaginatedResponse[OrganizationMemberResponse], summary="Get Organization Members", description="Get all members of an organization (members can view)", - operation_id="get_organization_members" + operation_id="get_organization_members", ) async def get_organization_members( organization_id: UUID, pagination: PaginationParams = Depends(), is_active: bool = Query(True, description="Filter by active status"), current_user: User = Depends(require_org_membership), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Get all members of an organization. @@ -153,7 +154,7 @@ async def get_organization_members( organization_id=organization_id, skip=pagination.offset, limit=pagination.limit, - is_active=is_active + is_active=is_active, ) member_responses = [OrganizationMemberResponse(**member) for member in members] @@ -162,13 +163,13 @@ async def get_organization_members( total=total, page=pagination.page, limit=pagination.limit, - items_count=len(member_responses) + items_count=len(member_responses), ) return PaginatedResponse(data=member_responses, pagination=pagination_meta) except Exception as e: - logger.error(f"Error getting organization members: {str(e)}", exc_info=True) + logger.error(f"Error getting organization members: {e!s}", exc_info=True) raise @@ -177,13 +178,13 @@ async def get_organization_members( response_model=OrganizationResponse, summary="Update Organization", description="Update organization details (admin/owner only)", - operation_id="update_organization" + operation_id="update_organization", ) async def update_organization( organization_id: UUID, org_in: OrganizationUpdate, current_user: User = Depends(require_org_admin), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Update organization details. @@ -195,11 +196,13 @@ async def update_organization( if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md) raise NotFoundError( detail=f"Organization {organization_id} not found", - error_code=ErrorCode.NOT_FOUND + error_code=ErrorCode.NOT_FOUND, ) updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in) - logger.info(f"User {current_user.email} updated organization {updated_org.name}") + logger.info( + f"User {current_user.email} updated organization {updated_org.name}" + ) org_dict = { "id": updated_org.id, @@ -210,12 +213,14 @@ async def update_organization( "settings": updated_org.settings, "created_at": updated_org.created_at, "updated_at": updated_org.updated_at, - "member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id) + "member_count": await organization_crud.get_member_count( + db, organization_id=updated_org.id + ), } return OrganizationResponse(**org_dict) except NotFoundError: # pragma: no cover - See above raise except Exception as e: - logger.error(f"Error updating organization: {str(e)}", exc_info=True) + logger.error(f"Error updating organization: {e!s}", exc_info=True) raise diff --git a/backend/app/api/routes/sessions.py b/backend/app/api/routes/sessions.py index f39056b..54f1451 100755 --- a/backend/app/api/routes/sessions.py +++ b/backend/app/api/routes/sessions.py @@ -3,11 +3,12 @@ Session management endpoints. Allows users to view and manage their active sessions across devices. """ + import logging from typing import Any from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi import APIRouter, Depends, HTTPException, Request, status from slowapi import Limiter from slowapi.util import get_remote_address from sqlalchemy.ext.asyncio import AsyncSession @@ -15,11 +16,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user from app.core.auth import decode_token from app.core.database import get_db -from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode +from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError from app.crud.session import session as session_crud from app.models.user import User from app.schemas.common import MessageResponse -from app.schemas.sessions import SessionResponse, SessionListResponse +from app.schemas.sessions import SessionListResponse, SessionResponse router = APIRouter() logger = logging.getLogger(__name__) @@ -39,13 +40,13 @@ limiter = Limiter(key_func=get_remote_address) **Rate Limit**: 30 requests/minute """, - operation_id="list_my_sessions" + operation_id="list_my_sessions", ) @limiter.limit("30/minute") async def list_my_sessions( request: Request, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ List all active sessions for the current user. @@ -60,18 +61,15 @@ async def list_my_sessions( try: # Get all active sessions for user sessions = await session_crud.get_user_sessions( - db, - user_id=str(current_user.id), - active_only=True + db, user_id=str(current_user.id), active_only=True ) # Try to identify current session from Authorization header - current_session_jti = None auth_header = request.headers.get("authorization") if auth_header and auth_header.startswith("Bearer "): try: access_token = auth_header.split(" ")[1] - token_payload = decode_token(access_token) + decode_token(access_token) # Note: Access tokens don't have JTI by default, but we can try # For now, we'll mark current based on most recent activity except Exception: @@ -90,22 +88,27 @@ async def list_my_sessions( last_used_at=s.last_used_at, created_at=s.created_at, expires_at=s.expires_at, - is_current=(s == sessions[0] if sessions else False) # Most recent = current + is_current=( + s == sessions[0] if sessions else False + ), # Most recent = current ) session_responses.append(session_response) - logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions") + logger.info( + f"User {current_user.id} listed {len(session_responses)} active sessions" + ) return SessionListResponse( - sessions=session_responses, - total=len(session_responses) + sessions=session_responses, total=len(session_responses) ) except Exception as e: - logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True) + logger.error( + f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve sessions" + detail="Failed to retrieve sessions", ) @@ -122,14 +125,14 @@ async def list_my_sessions( **Rate Limit**: 10 requests/minute """, - operation_id="revoke_session" + operation_id="revoke_session", ) @limiter.limit("10/minute") async def revoke_session( request: Request, session_id: UUID, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Revoke a specific session by ID. @@ -149,7 +152,7 @@ async def revoke_session( if not session: raise NotFoundError( message=f"Session {session_id} not found", - error_code=ErrorCode.NOT_FOUND + error_code=ErrorCode.NOT_FOUND, ) # Verify session belongs to current user @@ -160,7 +163,7 @@ async def revoke_session( ) raise AuthorizationError( message="You can only revoke your own sessions", - error_code=ErrorCode.INSUFFICIENT_PERMISSIONS + error_code=ErrorCode.INSUFFICIENT_PERMISSIONS, ) # Deactivate the session @@ -173,16 +176,16 @@ async def revoke_session( return MessageResponse( success=True, - message=f"Session revoked: {session.device_name or 'Unknown device'}" + message=f"Session revoked: {session.device_name or 'Unknown device'}", ) except (NotFoundError, AuthorizationError): raise except Exception as e: - logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True) + logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to revoke session" + detail="Failed to revoke session", ) @@ -198,13 +201,13 @@ async def revoke_session( **Rate Limit**: 5 requests/minute """, - operation_id="cleanup_expired_sessions" + operation_id="cleanup_expired_sessions", ) @limiter.limit("5/minute") async def cleanup_expired_sessions( request: Request, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Cleanup expired sessions for the current user. @@ -219,21 +222,24 @@ async def cleanup_expired_sessions( try: # Use optimized bulk DELETE instead of N individual deletes deleted_count = await session_crud.cleanup_expired_for_user( - db, - user_id=str(current_user.id) + db, user_id=str(current_user.id) ) - logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions") + logger.info( + f"User {current_user.id} cleaned up {deleted_count} expired sessions" + ) return MessageResponse( - success=True, - message=f"Cleaned up {deleted_count} expired sessions" + success=True, message=f"Cleaned up {deleted_count} expired sessions" ) except Exception as e: - logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True) + logger.error( + f"Error cleaning up sessions for user {current_user.id}: {e!s}", + exc_info=True, + ) await db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to cleanup sessions" + detail="Failed to cleanup sessions", ) diff --git a/backend/app/api/routes/users.py b/backend/app/api/routes/users.py index ef3e097..34790f8 100755 --- a/backend/app/api/routes/users.py +++ b/backend/app/api/routes/users.py @@ -1,33 +1,30 @@ """ User management endpoints for CRUD operations. """ + import logging -from typing import Any, Optional +from typing import Any from uuid import UUID -from fastapi import APIRouter, Depends, Query, status, Request +from fastapi import APIRouter, Depends, Query, Request, status from slowapi import Limiter from slowapi.util import get_remote_address from sqlalchemy.ext.asyncio import AsyncSession -from app.api.dependencies.auth import get_current_user, get_current_superuser +from app.api.dependencies.auth import get_current_superuser, get_current_user from app.core.database import get_db -from app.core.exceptions import ( - NotFoundError, - AuthorizationError, - ErrorCode -) +from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError from app.crud.user import user as user_crud from app.models.user import User from app.schemas.common import ( - PaginationParams, - PaginatedResponse, MessageResponse, + PaginatedResponse, + PaginationParams, SortParams, - create_pagination_meta + create_pagination_meta, ) -from app.schemas.users import UserResponse, UserUpdate, PasswordChange -from app.services.auth_service import AuthService, AuthenticationError +from app.schemas.users import PasswordChange, UserResponse, UserUpdate +from app.services.auth_service import AuthenticationError, AuthService logger = logging.getLogger(__name__) @@ -50,15 +47,15 @@ limiter = Limiter(key_func=get_remote_address) **Rate Limit**: 60 requests/minute """, - operation_id="list_users" + operation_id="list_users", ) async def list_users( pagination: PaginationParams = Depends(), sort: SortParams = Depends(), - is_active: Optional[bool] = Query(None, description="Filter by active status"), - is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), + is_active: bool | None = Query(None, description="Filter by active status"), + is_superuser: bool | None = Query(None, description="Filter by superuser status"), current_user: User = Depends(get_current_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ List all users with pagination, filtering, and sorting. @@ -80,7 +77,7 @@ async def list_users( limit=pagination.limit, sort_by=sort.sort_by, sort_order=sort.sort_order.value if sort.sort_order else "asc", - filters=filters if filters else None + filters=filters if filters else None, ) # Create pagination metadata @@ -88,15 +85,12 @@ async def list_users( total=total, page=pagination.page, limit=pagination.limit, - items_count=len(users) + items_count=len(users), ) - return PaginatedResponse( - data=users, - pagination=pagination_meta - ) + return PaginatedResponse(data=users, pagination=pagination_meta) except Exception as e: - logger.error(f"Error listing users: {str(e)}", exc_info=True) + logger.error(f"Error listing users: {e!s}", exc_info=True) raise @@ -111,11 +105,9 @@ async def list_users( **Rate Limit**: 60 requests/minute """, - operation_id="get_current_user_profile" + operation_id="get_current_user_profile", ) -def get_current_user_profile( - current_user: User = Depends(get_current_user) -) -> Any: +def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any: """Get current user's profile.""" return current_user @@ -133,12 +125,12 @@ def get_current_user_profile( **Rate Limit**: 30 requests/minute """, - operation_id="update_current_user" + operation_id="update_current_user", ) async def update_current_user( user_update: UserUpdate, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Update current user's profile. @@ -147,17 +139,17 @@ async def update_current_user( """ try: updated_user = await user_crud.update( - db, - db_obj=current_user, - obj_in=user_update + db, db_obj=current_user, obj_in=user_update ) logger.info(f"User {current_user.id} updated their profile") return updated_user except ValueError as e: - logger.error(f"Error updating user {current_user.id}: {str(e)}") + logger.error(f"Error updating user {current_user.id}: {e!s}") raise except Exception as e: - logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True) + logger.error( + f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True + ) raise @@ -175,12 +167,12 @@ async def update_current_user( **Rate Limit**: 60 requests/minute """, - operation_id="get_user_by_id" + operation_id="get_user_by_id", ) async def get_user_by_id( user_id: UUID, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Get user by ID. @@ -194,7 +186,7 @@ async def get_user_by_id( ) raise AuthorizationError( message="Not enough permissions to view this user", - error_code=ErrorCode.INSUFFICIENT_PERMISSIONS + error_code=ErrorCode.INSUFFICIENT_PERMISSIONS, ) # Get user @@ -202,7 +194,7 @@ async def get_user_by_id( if not user: raise NotFoundError( message=f"User with id {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + error_code=ErrorCode.USER_NOT_FOUND, ) return user @@ -222,13 +214,13 @@ async def get_user_by_id( **Rate Limit**: 30 requests/minute """, - operation_id="update_user" + operation_id="update_user", ) async def update_user( user_id: UUID, user_update: UserUpdate, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Update user by ID. @@ -245,7 +237,7 @@ async def update_user( ) raise AuthorizationError( message="Not enough permissions to update this user", - error_code=ErrorCode.INSUFFICIENT_PERMISSIONS + error_code=ErrorCode.INSUFFICIENT_PERMISSIONS, ) # Get user @@ -253,7 +245,7 @@ async def update_user( if not user: raise NotFoundError( message=f"User with id {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + error_code=ErrorCode.USER_NOT_FOUND, ) try: @@ -261,10 +253,10 @@ async def update_user( logger.info(f"User {user_id} updated by {current_user.id}") return updated_user except ValueError as e: - logger.error(f"Error updating user {user_id}: {str(e)}") + logger.error(f"Error updating user {user_id}: {e!s}") raise except Exception as e: - logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True) + logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True) raise @@ -281,14 +273,14 @@ async def update_user( **Rate Limit**: 5 requests/minute """, - operation_id="change_current_user_password" + operation_id="change_current_user_password", ) @limiter.limit("5/minute") async def change_current_user_password( request: Request, password_change: PasswordChange, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Change current user's password. @@ -300,23 +292,23 @@ async def change_current_user_password( db=db, user_id=current_user.id, current_password=password_change.current_password, - new_password=password_change.new_password + new_password=password_change.new_password, ) if success: logger.info(f"User {current_user.id} changed their password") return MessageResponse( - success=True, - message="Password changed successfully" + success=True, message="Password changed successfully" ) except AuthenticationError as e: - logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}") + logger.warning( + f"Failed password change attempt for user {current_user.id}: {e!s}" + ) raise AuthorizationError( - message=str(e), - error_code=ErrorCode.INVALID_CREDENTIALS + message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS ) except Exception as e: - logger.error(f"Error changing password for user {current_user.id}: {str(e)}") + logger.error(f"Error changing password for user {current_user.id}: {e!s}") raise @@ -335,12 +327,12 @@ async def change_current_user_password( **Note**: This performs a hard delete. Consider implementing soft deletes for production. """, - operation_id="delete_user" + operation_id="delete_user", ) async def delete_user( user_id: UUID, current_user: User = Depends(get_current_superuser), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> Any: """ Delete user by ID (superuser only). @@ -351,7 +343,7 @@ async def delete_user( if str(user_id) == str(current_user.id): raise AuthorizationError( message="Cannot delete your own account", - error_code=ErrorCode.INSUFFICIENT_PERMISSIONS + error_code=ErrorCode.INSUFFICIENT_PERMISSIONS, ) # Get user @@ -359,7 +351,7 @@ async def delete_user( if not user: raise NotFoundError( message=f"User with id {user_id} not found", - error_code=ErrorCode.USER_NOT_FOUND + error_code=ErrorCode.USER_NOT_FOUND, ) try: @@ -367,12 +359,11 @@ async def delete_user( await user_crud.soft_delete(db, id=str(user_id)) logger.info(f"User {user_id} soft-deleted by {current_user.id}") return MessageResponse( - success=True, - message=f"User {user_id} deleted successfully" + success=True, message=f"User {user_id} deleted successfully" ) except ValueError as e: - logger.error(f"Error deleting user {user_id}: {str(e)}") + logger.error(f"Error deleting user {user_id}: {e!s}") raise except Exception as e: - logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True) + logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True) raise diff --git a/backend/app/core/auth.py b/backend/app/core/auth.py index e73623b..b1c0b4e 100644 --- a/backend/app/core/auth.py +++ b/backend/app/core/auth.py @@ -1,39 +1,39 @@ import logging -logging.getLogger('passlib').setLevel(logging.ERROR) -from datetime import datetime, timedelta, timezone -from typing import Any, Dict, Optional, Union -import uuid +logging.getLogger("passlib").setLevel(logging.ERROR) + import asyncio +import uuid +from datetime import UTC, datetime, timedelta from functools import partial +from typing import Any -from jose import jwt, JWTError +from jose import JWTError, jwt from passlib.context import CryptContext from pydantic import ValidationError from app.core.config import settings from app.schemas.users import TokenData, TokenPayload - # Password hashing context pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + # Custom exceptions for auth class AuthError(Exception): """Base authentication error""" - pass + class TokenExpiredError(AuthError): """Token has expired""" - pass + class TokenInvalidError(AuthError): """Token is invalid""" - pass + class TokenMissingClaimError(AuthError): """Token is missing a required claim""" - pass def verify_password(plain_password: str, hashed_password: str) -> bool: @@ -62,8 +62,7 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo """ loop = asyncio.get_event_loop() return await loop.run_in_executor( - None, - partial(pwd_context.verify, plain_password, hashed_password) + None, partial(pwd_context.verify, plain_password, hashed_password) ) @@ -82,17 +81,13 @@ async def get_password_hash_async(password: str) -> str: Hashed password string """ loop = asyncio.get_event_loop() - return await loop.run_in_executor( - None, - pwd_context.hash, - password - ) + return await loop.run_in_executor(None, pwd_context.hash, password) def create_access_token( - subject: Union[str, Any], - expires_delta: Optional[timedelta] = None, - claims: Optional[Dict[str, Any]] = None + subject: str | Any, + expires_delta: timedelta | None = None, + claims: dict[str, Any] | None = None, ) -> str: """ Create a JWT access token. @@ -106,17 +101,19 @@ def create_access_token( Encoded JWT token """ if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta + expire = datetime.now(UTC) + expires_delta else: - expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.now(UTC) + timedelta( + minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + ) # Base token data to_encode = { "sub": str(subject), "exp": expire, - "iat": datetime.now(tz=timezone.utc), + "iat": datetime.now(tz=UTC), "jti": str(uuid.uuid4()), - "type": "access" + "type": "access", } # Add custom claims @@ -125,17 +122,14 @@ def create_access_token( # Create the JWT encoded_jwt = jwt.encode( - to_encode, - settings.SECRET_KEY, - algorithm=settings.ALGORITHM + to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM ) return encoded_jwt def create_refresh_token( - subject: Union[str, Any], - expires_delta: Optional[timedelta] = None + subject: str | Any, expires_delta: timedelta | None = None ) -> str: """ Create a JWT refresh token. @@ -148,28 +142,26 @@ def create_refresh_token( Encoded JWT refresh token """ if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta + expire = datetime.now(UTC) + expires_delta else: - expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + expire = datetime.now(UTC) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) to_encode = { "sub": str(subject), "exp": expire, - "iat": datetime.now(timezone.utc), + "iat": datetime.now(UTC), "jti": str(uuid.uuid4()), - "type": "refresh" + "type": "refresh", } encoded_jwt = jwt.encode( - to_encode, - settings.SECRET_KEY, - algorithm=settings.ALGORITHM + to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM ) return encoded_jwt -def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload: +def decode_token(token: str, verify_type: str | None = None) -> TokenPayload: """ Decode and verify a JWT token. @@ -195,8 +187,8 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload: "verify_signature": True, "verify_exp": True, "verify_iat": True, - "require": ["exp", "sub", "iat"] - } + "require": ["exp", "sub", "iat"], + }, ) # SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks @@ -250,4 +242,4 @@ def get_token_data(token: str) -> TokenData: user_id = payload.sub is_superuser = payload.is_superuser or False - return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser) \ No newline at end of file + return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 53c6cba..11def0c 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -1,5 +1,4 @@ import logging -from typing import Optional, List from pydantic import Field, field_validator from pydantic_settings import BaseSettings @@ -13,7 +12,7 @@ class Settings(BaseSettings): # Environment (must be before SECRET_KEY for validation) ENVIRONMENT: str = Field( default="development", - description="Environment: development, staging, or production" + description="Environment: development, staging, or production", ) # Security: Content Security Policy @@ -21,8 +20,7 @@ class Settings(BaseSettings): # Set to True for strict CSP (blocks most external resources) # Set to "relaxed" for modern frontend development CSP_MODE: str = Field( - default="relaxed", - description="CSP mode: 'strict', 'relaxed', or 'disabled'" + default="relaxed", description="CSP mode: 'strict', 'relaxed', or 'disabled'" ) # Database configuration @@ -31,7 +29,7 @@ class Settings(BaseSettings): POSTGRES_HOST: str = "localhost" POSTGRES_PORT: str = "5432" POSTGRES_DB: str = "app" - DATABASE_URL: Optional[str] = None + DATABASE_URL: str | None = None db_pool_size: int = 20 # Default connection pool size db_max_overflow: int = 50 # Maximum overflow connections db_pool_timeout: int = 30 # Seconds to wait for a connection @@ -59,38 +57,36 @@ class Settings(BaseSettings): SECRET_KEY: str = Field( default="dev_only_insecure_key_change_in_production_32chars_min", min_length=32, - description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'" + description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'", ) ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard) REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days # CORS configuration - BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"] + BACKEND_CORS_ORIGINS: list[str] = ["http://localhost:3000"] # Frontend URL for email links FRONTEND_URL: str = Field( default="http://localhost:3000", - description="Frontend application URL for email links" + description="Frontend application URL for email links", ) # Admin user - FIRST_SUPERUSER_EMAIL: Optional[str] = Field( - default=None, - description="Email for first superuser account" + FIRST_SUPERUSER_EMAIL: str | None = Field( + default=None, description="Email for first superuser account" ) - FIRST_SUPERUSER_PASSWORD: Optional[str] = Field( - default=None, - description="Password for first superuser (min 12 characters)" + FIRST_SUPERUSER_PASSWORD: str | None = Field( + default=None, description="Password for first superuser (min 12 characters)" ) - @field_validator('SECRET_KEY') + @field_validator("SECRET_KEY") @classmethod def validate_secret_key(cls, v: str, info) -> str: """Validate SECRET_KEY is secure, especially in production.""" # Get environment from values if available values_data = info.data if info.data else {} - env = values_data.get('ENVIRONMENT', 'development') + env = values_data.get("ENVIRONMENT", "development") if v.startswith("your_secret_key_here"): if env == "production": @@ -106,13 +102,15 @@ class Settings(BaseSettings): ) if len(v) < 32: - raise ValueError("SECRET_KEY must be at least 32 characters long for security") + raise ValueError( + "SECRET_KEY must be at least 32 characters long for security" + ) return v - @field_validator('FIRST_SUPERUSER_PASSWORD') + @field_validator("FIRST_SUPERUSER_PASSWORD") @classmethod - def validate_superuser_password(cls, v: Optional[str]) -> Optional[str]: + def validate_superuser_password(cls, v: str | None) -> str | None: """Validate superuser password strength.""" if v is None: return v @@ -121,7 +119,13 @@ class Settings(BaseSettings): raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters") # Check for common weak passwords - weak_passwords = {'admin123', 'Admin123', 'password123', 'Password123', '123456789012'} + weak_passwords = { + "admin123", + "Admin123", + "password123", + "Password123", + "123456789012", + } if v in weak_passwords: raise ValueError( "FIRST_SUPERUSER_PASSWORD is too weak. " @@ -144,8 +148,8 @@ class Settings(BaseSettings): "env_file": "../.env", "env_file_encoding": "utf-8", "case_sensitive": True, - "extra": "ignore" # Ignore extra fields from .env (e.g., frontend-specific vars) + "extra": "ignore", # Ignore extra fields from .env (e.g., frontend-specific vars) } -settings = Settings() \ No newline at end of file +settings = Settings() diff --git a/backend/app/core/database.py b/backend/app/core/database.py index 1265164..749aa2d 100755 --- a/backend/app/core/database.py +++ b/backend/app/core/database.py @@ -5,17 +5,18 @@ Database configuration using SQLAlchemy 2.0 and asyncpg. This module provides async database connectivity with proper connection pooling and session management for FastAPI endpoints. """ + import logging +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import AsyncGenerator from sqlalchemy import text from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.ext.asyncio import ( - AsyncSession, AsyncEngine, - create_async_engine, + AsyncSession, async_sessionmaker, + create_async_engine, ) from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import DeclarativeBase @@ -27,12 +28,12 @@ logger = logging.getLogger(__name__) # SQLite compatibility for testing -@compiles(JSONB, 'sqlite') +@compiles(JSONB, "sqlite") def compile_jsonb_sqlite(type_, compiler, **kw): return "TEXT" -@compiles(UUID, 'sqlite') +@compiles(UUID, "sqlite") def compile_uuid_sqlite(type_, compiler, **kw): return "TEXT" @@ -40,7 +41,6 @@ def compile_uuid_sqlite(type_, compiler, **kw): # Declarative base for models (SQLAlchemy 2.0 style) class Base(DeclarativeBase): """Base class for all database models.""" - pass def get_async_database_url(url: str) -> str: @@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]: logger.debug("Async transaction committed successfully") except Exception as e: await session.rollback() - logger.error(f"Async transaction failed, rolling back: {str(e)}") + logger.error(f"Async transaction failed, rolling back: {e!s}") raise finally: await session.close() @@ -155,7 +155,7 @@ async def check_async_database_health() -> bool: await db.execute(text("SELECT 1")) return True except Exception as e: - logger.error(f"Async database health check failed: {str(e)}") + logger.error(f"Async database health check failed: {e!s}") return False diff --git a/backend/app/core/exceptions.py b/backend/app/core/exceptions.py index af39ace..0b45b09 100644 --- a/backend/app/core/exceptions.py +++ b/backend/app/core/exceptions.py @@ -1,8 +1,8 @@ """ Custom exceptions and global exception handlers for the API. """ + import logging -from typing import Optional, Union from fastapi import HTTPException, Request, status from fastapi.exceptions import RequestValidationError @@ -27,17 +27,13 @@ class APIException(HTTPException): status_code: int, error_code: ErrorCode, message: str, - field: Optional[str] = None, - headers: Optional[dict] = None + field: str | None = None, + headers: dict | None = None, ): self.error_code = error_code self.field = field self.message = message - super().__init__( - status_code=status_code, - detail=message, - headers=headers - ) + super().__init__(status_code=status_code, detail=message, headers=headers) class AuthenticationError(APIException): @@ -47,14 +43,14 @@ class AuthenticationError(APIException): self, message: str = "Authentication failed", error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS, - field: Optional[str] = None + field: str | None = None, ): super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, error_code=error_code, message=message, field=field, - headers={"WWW-Authenticate": "Bearer"} + headers={"WWW-Authenticate": "Bearer"}, ) @@ -64,12 +60,12 @@ class AuthorizationError(APIException): def __init__( self, message: str = "Insufficient permissions", - error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS + error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS, ): super().__init__( status_code=status.HTTP_403_FORBIDDEN, error_code=error_code, - message=message + message=message, ) @@ -79,12 +75,12 @@ class NotFoundError(APIException): def __init__( self, message: str = "Resource not found", - error_code: ErrorCode = ErrorCode.NOT_FOUND + error_code: ErrorCode = ErrorCode.NOT_FOUND, ): super().__init__( status_code=status.HTTP_404_NOT_FOUND, error_code=error_code, - message=message + message=message, ) @@ -95,13 +91,13 @@ class DuplicateError(APIException): self, message: str = "Resource already exists", error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY, - field: Optional[str] = None + field: str | None = None, ): super().__init__( status_code=status.HTTP_409_CONFLICT, error_code=error_code, message=message, - field=field + field=field, ) @@ -112,13 +108,13 @@ class ValidationException(APIException): self, message: str = "Validation error", error_code: ErrorCode = ErrorCode.VALIDATION_ERROR, - field: Optional[str] = None + field: str | None = None, ): super().__init__( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, error_code=error_code, message=message, - field=field + field=field, ) @@ -128,12 +124,12 @@ class DatabaseError(APIException): def __init__( self, message: str = "Database operation failed", - error_code: ErrorCode = ErrorCode.DATABASE_ERROR + error_code: ErrorCode = ErrorCode.DATABASE_ERROR, ): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, error_code=error_code, - message=message + message=message, ) @@ -152,23 +148,18 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp ) error_response = ErrorResponse( - errors=[ErrorDetail( - code=exc.error_code, - message=exc.message, - field=exc.field - )] + errors=[ErrorDetail(code=exc.error_code, message=exc.message, field=exc.field)] ) return JSONResponse( status_code=exc.status_code, content=error_response.model_dump(), - headers=exc.headers + headers=exc.headers, ) async def validation_exception_handler( - request: Request, - exc: Union[RequestValidationError, ValidationError] + request: Request, exc: RequestValidationError | ValidationError ) -> JSONResponse: """ Handler for Pydantic validation errors. @@ -189,22 +180,19 @@ async def validation_exception_handler( # Skip 'body' or 'query' prefix in location field = ".".join(str(x) for x in error["loc"][1:]) - errors.append(ErrorDetail( - code=ErrorCode.VALIDATION_ERROR, - message=error["msg"], - field=field - )) + errors.append( + ErrorDetail( + code=ErrorCode.VALIDATION_ERROR, message=error["msg"], field=field + ) + ) - logger.warning( - f"Validation error: {len(errors)} errors " - f"(path: {request.url.path})" - ) + logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})") error_response = ErrorResponse(errors=errors) return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - content=error_response.model_dump() + content=error_response.model_dump(), ) @@ -226,26 +214,21 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe } error_code = status_code_to_error_code.get( - exc.status_code, - ErrorCode.INTERNAL_ERROR + exc.status_code, ErrorCode.INTERNAL_ERROR ) logger.warning( - f"HTTP exception: {exc.status_code} - {exc.detail} " - f"(path: {request.url.path})" + f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})" ) error_response = ErrorResponse( - errors=[ErrorDetail( - code=error_code, - message=str(exc.detail) - )] + errors=[ErrorDetail(code=error_code, message=str(exc.detail))] ) return JSONResponse( status_code=exc.status_code, content=error_response.model_dump(), - headers=exc.headers + headers=exc.headers, ) @@ -257,26 +240,24 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR leaking sensitive information in production. """ logger.error( - f"Unhandled exception: {type(exc).__name__} - {str(exc)} " + f"Unhandled exception: {type(exc).__name__} - {exc!s} " f"(path: {request.url.path})", - exc_info=True + exc_info=True, ) # In production, don't expose internal error details from app.core.config import settings + if settings.ENVIRONMENT == "production": message = "An internal error occurred. Please try again later." else: - message = f"{type(exc).__name__}: {str(exc)}" + message = f"{type(exc).__name__}: {exc!s}" error_response = ErrorResponse( - errors=[ErrorDetail( - code=ErrorCode.INTERNAL_ERROR, - message=message - )] + errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)] ) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=error_response.model_dump() + content=error_response.model_dump(), ) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 734fad0..46d2542 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -3,4 +3,4 @@ from .organization import organization from .session import session as session_crud from .user import user -__all__ = ["user", "session_crud", "organization"] +__all__ = ["organization", "session_crud", "user"] diff --git a/backend/app/crud/base.py b/backend/app/crud/base.py index ac66c09..c562f27 100755 --- a/backend/app/crud/base.py +++ b/backend/app/crud/base.py @@ -4,14 +4,16 @@ Async CRUD operations base class using SQLAlchemy 2.0 async patterns. Provides reusable create, read, update, and delete operations for all models. """ + import logging import uuid -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple +from datetime import UTC +from typing import Any, TypeVar from fastapi.encoders import jsonable_encoder from pydantic import BaseModel from sqlalchemy import func, select -from sqlalchemy.exc import IntegrityError, OperationalError, DataError +from sqlalchemy.exc import DataError, IntegrityError, OperationalError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Load @@ -24,10 +26,14 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) -class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): +class CRUDBase[ + ModelType: Base, + CreateSchemaType: BaseModel, + UpdateSchemaType: BaseModel, +]: """Async CRUD operations for a model.""" - def __init__(self, model: Type[ModelType]): + def __init__(self, model: type[ModelType]): """ CRUD object with default async methods to Create, Read, Update, Delete. @@ -37,11 +43,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): self.model = model async def get( - self, - db: AsyncSession, - id: str, - options: Optional[List[Load]] = None - ) -> Optional[ModelType]: + self, db: AsyncSession, id: str, options: list[Load] | None = None + ) -> ModelType | None: """ Get a single record by ID with UUID validation and optional eager loading. @@ -66,7 +69,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): else: uuid_obj = uuid.UUID(str(id)) except (ValueError, AttributeError, TypeError) as e: - logger.warning(f"Invalid UUID format: {id} - {str(e)}") + logger.warning(f"Invalid UUID format: {id} - {e!s}") return None try: @@ -80,7 +83,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): result = await db.execute(query) return result.scalar_one_or_none() except Exception as e: - logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}") + logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}") raise async def get_multi( @@ -89,8 +92,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): *, skip: int = 0, limit: int = 100, - options: Optional[List[Load]] = None - ) -> List[ModelType]: + options: list[Load] | None = None, + ) -> list[ModelType]: """ Get multiple records with pagination validation and optional eager loading. @@ -122,10 +125,14 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): result = await db.execute(query) return list(result.scalars().all()) except Exception as e: - logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}") + logger.error( + f"Error retrieving multiple {self.model.__name__} records: {e!s}" + ) raise - async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: # pragma: no cover + async def create( + self, db: AsyncSession, *, obj_in: CreateSchemaType + ) -> ModelType: # pragma: no cover """Create a new record with error handling. NOTE: This method is defensive code that's never called in practice. @@ -142,19 +149,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): return db_obj except IntegrityError as e: # pragma: no cover await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): - logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") - raise ValueError(f"A {self.model.__name__} with this data already exists") + logger.warning( + f"Duplicate entry attempted for {self.model.__name__}: {error_msg}" + ) + raise ValueError( + f"A {self.model.__name__} with this data already exists" + ) logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}") except (OperationalError, DataError) as e: # pragma: no cover await db.rollback() - logger.error(f"Database error creating {self.model.__name__}: {str(e)}") - raise ValueError(f"Database operation failed: {str(e)}") + logger.error(f"Database error creating {self.model.__name__}: {e!s}") + raise ValueError(f"Database operation failed: {e!s}") except Exception as e: # pragma: no cover await db.rollback() - logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True) + logger.error( + f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True + ) raise async def update( @@ -162,7 +175,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): db: AsyncSession, *, db_obj: ModelType, - obj_in: Union[UpdateSchemaType, Dict[str, Any]] + obj_in: UpdateSchemaType | dict[str, Any], ) -> ModelType: """Update a record with error handling.""" try: @@ -182,22 +195,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): return db_obj except IntegrityError as e: await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): - logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") - raise ValueError(f"A {self.model.__name__} with this data already exists") + logger.warning( + f"Duplicate entry attempted for {self.model.__name__}: {error_msg}" + ) + raise ValueError( + f"A {self.model.__name__} with this data already exists" + ) logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}") except (OperationalError, DataError) as e: await db.rollback() - logger.error(f"Database error updating {self.model.__name__}: {str(e)}") - raise ValueError(f"Database operation failed: {str(e)}") + logger.error(f"Database error updating {self.model.__name__}: {e!s}") + raise ValueError(f"Database operation failed: {e!s}") except Exception as e: await db.rollback() - logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True) + logger.error( + f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True + ) raise - async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: + async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None: """Delete a record with error handling and null check.""" # Validate UUID format and convert to UUID object if string try: @@ -206,7 +225,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): else: uuid_obj = uuid.UUID(str(id)) except (ValueError, AttributeError, TypeError) as e: - logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}") + logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}") return None try: @@ -216,7 +235,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): obj = result.scalar_one_or_none() if obj is None: - logger.warning(f"{self.model.__name__} with id {id} not found for deletion") + logger.warning( + f"{self.model.__name__} with id {id} not found for deletion" + ) return None await db.delete(obj) @@ -224,12 +245,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): return obj except IntegrityError as e: await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}") - raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records") + raise ValueError( + f"Cannot delete {self.model.__name__}: referenced by other records" + ) except Exception as e: await db.rollback() - logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) + logger.error( + f"Error deleting {self.model.__name__} with id {id}: {e!s}", + exc_info=True, + ) raise async def get_multi_with_total( @@ -238,10 +264,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): *, skip: int = 0, limit: int = 100, - sort_by: Optional[str] = None, + sort_by: str | None = None, sort_order: str = "asc", - filters: Optional[Dict[str, Any]] = None - ) -> Tuple[List[ModelType], int]: + filters: dict[str, Any] | None = None, + ) -> tuple[list[ModelType], int]: """ Get multiple records with total count, filtering, and sorting. @@ -269,7 +295,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): query = select(self.model) # Exclude soft-deleted records by default - if hasattr(self.model, 'deleted_at'): + if hasattr(self.model, "deleted_at"): query = query.where(self.model.deleted_at.is_(None)) # Apply filters @@ -298,7 +324,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): return items, total except Exception as e: - logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}") + logger.error( + f"Error retrieving paginated {self.model.__name__} records: {e!s}" + ) raise async def count(self, db: AsyncSession) -> int: @@ -307,7 +335,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): result = await db.execute(select(func.count(self.model.id))) return result.scalar_one() except Exception as e: - logger.error(f"Error counting {self.model.__name__} records: {str(e)}") + logger.error(f"Error counting {self.model.__name__} records: {e!s}") raise async def exists(self, db: AsyncSession, id: str) -> bool: @@ -315,13 +343,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): obj = await self.get(db, id=id) return obj is not None - async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: + async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None: """ Soft delete a record by setting deleted_at timestamp. Only works if the model has a 'deleted_at' column. """ - from datetime import datetime, timezone + from datetime import datetime # Validate UUID format and convert to UUID object if string try: @@ -330,7 +358,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): else: uuid_obj = uuid.UUID(str(id)) except (ValueError, AttributeError, TypeError) as e: - logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}") + logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}") return None try: @@ -340,26 +368,33 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): obj = result.scalar_one_or_none() if obj is None: - logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion") + logger.warning( + f"{self.model.__name__} with id {id} not found for soft deletion" + ) return None # Check if model supports soft deletes - if not hasattr(self.model, 'deleted_at'): + if not hasattr(self.model, "deleted_at"): logger.error(f"{self.model.__name__} does not support soft deletes") - raise ValueError(f"{self.model.__name__} does not have a deleted_at column") + raise ValueError( + f"{self.model.__name__} does not have a deleted_at column" + ) # Set deleted_at timestamp - obj.deleted_at = datetime.now(timezone.utc) + obj.deleted_at = datetime.now(UTC) db.add(obj) await db.commit() await db.refresh(obj) return obj except Exception as e: await db.rollback() - logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) + logger.error( + f"Error soft deleting {self.model.__name__} with id {id}: {e!s}", + exc_info=True, + ) raise - async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: + async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None: """ Restore a soft-deleted record by clearing the deleted_at timestamp. @@ -372,25 +407,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): else: uuid_obj = uuid.UUID(str(id)) except (ValueError, AttributeError, TypeError) as e: - logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}") + logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}") return None try: # Find the soft-deleted record - if hasattr(self.model, 'deleted_at'): + if hasattr(self.model, "deleted_at"): result = await db.execute( select(self.model).where( - self.model.id == uuid_obj, - self.model.deleted_at.isnot(None) + self.model.id == uuid_obj, self.model.deleted_at.isnot(None) ) ) obj = result.scalar_one_or_none() else: logger.error(f"{self.model.__name__} does not support soft deletes") - raise ValueError(f"{self.model.__name__} does not have a deleted_at column") + raise ValueError( + f"{self.model.__name__} does not have a deleted_at column" + ) if obj is None: - logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration") + logger.warning( + f"Soft-deleted {self.model.__name__} with id {id} not found for restoration" + ) return None # Clear deleted_at timestamp @@ -401,5 +439,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): return obj except Exception as e: await db.rollback() - logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True) + logger.error( + f"Error restoring {self.model.__name__} with id {id}: {e!s}", + exc_info=True, + ) raise diff --git a/backend/app/crud/organization.py b/backend/app/crud/organization.py index 3e98cd9..85ef256 100755 --- a/backend/app/crud/organization.py +++ b/backend/app/crud/organization.py @@ -1,17 +1,18 @@ # app/crud/organization_async.py """Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns.""" + import logging -from typing import Optional, List, Dict, Any +from typing import Any from uuid import UUID -from sqlalchemy import func, or_, and_, select, case +from sqlalchemy import and_, case, func, or_, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.crud.base import CRUDBase from app.models.organization import Organization from app.models.user import User -from app.models.user_organization import UserOrganization, OrganizationRole +from app.models.user_organization import OrganizationRole, UserOrganization from app.schemas.organizations import ( OrganizationCreate, OrganizationUpdate, @@ -23,7 +24,7 @@ logger = logging.getLogger(__name__) class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]): """Async CRUD operations for Organization model.""" - async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]: + async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None: """Get organization by slug.""" try: result = await db.execute( @@ -31,10 +32,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp ) return result.scalar_one_or_none() except Exception as e: - logger.error(f"Error getting organization by slug {slug}: {str(e)}") + logger.error(f"Error getting organization by slug {slug}: {e!s}") raise - async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization: + async def create( + self, db: AsyncSession, *, obj_in: OrganizationCreate + ) -> Organization: """Create a new organization with error handling.""" try: db_obj = Organization( @@ -42,7 +45,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp slug=obj_in.slug, description=obj_in.description, is_active=obj_in.is_active, - settings=obj_in.settings or {} + settings=obj_in.settings or {}, ) db.add(db_obj) await db.commit() @@ -50,15 +53,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp return db_obj except IntegrityError as e: await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) if "slug" in error_msg.lower(): logger.warning(f"Duplicate slug attempted: {obj_in.slug}") - raise ValueError(f"Organization with slug '{obj_in.slug}' already exists") + raise ValueError( + f"Organization with slug '{obj_in.slug}' already exists" + ) logger.error(f"Integrity error creating organization: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}") except Exception as e: await db.rollback() - logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True) + logger.error( + f"Unexpected error creating organization: {e!s}", exc_info=True + ) raise async def get_multi_with_filters( @@ -67,11 +74,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp *, skip: int = 0, limit: int = 100, - is_active: Optional[bool] = None, - search: Optional[str] = None, + is_active: bool | None = None, + search: str | None = None, sort_by: str = "created_at", - sort_order: str = "desc" - ) -> tuple[List[Organization], int]: + sort_order: str = "desc", + ) -> tuple[list[Organization], int]: """ Get multiple organizations with filtering, searching, and sorting. @@ -89,7 +96,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp search_filter = or_( Organization.name.ilike(f"%{search}%"), Organization.slug.ilike(f"%{search}%"), - Organization.description.ilike(f"%{search}%") + Organization.description.ilike(f"%{search}%"), ) query = query.where(search_filter) @@ -112,7 +119,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp return organizations, total except Exception as e: - logger.error(f"Error getting organizations with filters: {str(e)}") + logger.error(f"Error getting organizations with filters: {e!s}") raise async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int: @@ -122,13 +129,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp select(func.count(UserOrganization.user_id)).where( and_( UserOrganization.organization_id == organization_id, - UserOrganization.is_active == True + UserOrganization.is_active, ) ) ) return result.scalar_one() or 0 except Exception as e: - logger.error(f"Error getting member count for organization {organization_id}: {str(e)}") + logger.error( + f"Error getting member count for organization {organization_id}: {e!s}" + ) raise async def get_multi_with_member_counts( @@ -137,9 +146,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp *, skip: int = 0, limit: int = 100, - is_active: Optional[bool] = None, - search: Optional[str] = None - ) -> tuple[List[Dict[str, Any]], int]: + is_active: bool | None = None, + search: str | None = None, + ) -> tuple[list[dict[str, Any]], int]: """ Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY. This eliminates the N+1 query problem. @@ -156,13 +165,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp func.count( func.distinct( case( - (UserOrganization.is_active == True, UserOrganization.user_id), - else_=None + ( + UserOrganization.is_active, + UserOrganization.user_id, + ), + else_=None, ) ) - ).label('member_count') + ).label("member_count"), + ) + .outerjoin( + UserOrganization, + Organization.id == UserOrganization.organization_id, ) - .outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id) .group_by(Organization.id) ) @@ -174,7 +189,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp search_filter = or_( Organization.name.ilike(f"%{search}%"), Organization.slug.ilike(f"%{search}%"), - Organization.description.ilike(f"%{search}%") + Organization.description.ilike(f"%{search}%"), ) query = query.where(search_filter) @@ -189,24 +204,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp total = count_result.scalar_one() # Apply pagination and ordering - query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit) + query = ( + query.order_by(Organization.created_at.desc()).offset(skip).limit(limit) + ) result = await db.execute(query) rows = result.all() # Convert to list of dicts orgs_with_counts = [ - { - 'organization': org, - 'member_count': member_count - } + {"organization": org, "member_count": member_count} for org, member_count in rows ] return orgs_with_counts, total except Exception as e: - logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True) + logger.error( + f"Error getting organizations with member counts: {e!s}", exc_info=True + ) raise async def add_user( @@ -216,7 +232,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp organization_id: UUID, user_id: UUID, role: OrganizationRole = OrganizationRole.MEMBER, - custom_permissions: Optional[str] = None + custom_permissions: str | None = None, ) -> UserOrganization: """Add a user to an organization with a specific role.""" try: @@ -225,7 +241,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp select(UserOrganization).where( and_( UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id + UserOrganization.organization_id == organization_id, ) ) ) @@ -249,7 +265,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp organization_id=organization_id, role=role, is_active=True, - custom_permissions=custom_permissions + custom_permissions=custom_permissions, ) db.add(user_org) await db.commit() @@ -257,19 +273,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp return user_org except IntegrityError as e: await db.rollback() - logger.error(f"Integrity error adding user to organization: {str(e)}") + logger.error(f"Integrity error adding user to organization: {e!s}") raise ValueError("Failed to add user to organization") except Exception as e: await db.rollback() - logger.error(f"Error adding user to organization: {str(e)}", exc_info=True) + logger.error(f"Error adding user to organization: {e!s}", exc_info=True) raise async def remove_user( - self, - db: AsyncSession, - *, - organization_id: UUID, - user_id: UUID + self, db: AsyncSession, *, organization_id: UUID, user_id: UUID ) -> bool: """Remove a user from an organization (soft delete).""" try: @@ -277,7 +289,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp select(UserOrganization).where( and_( UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id + UserOrganization.organization_id == organization_id, ) ) ) @@ -291,7 +303,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp return True except Exception as e: await db.rollback() - logger.error(f"Error removing user from organization: {str(e)}", exc_info=True) + logger.error(f"Error removing user from organization: {e!s}", exc_info=True) raise async def update_user_role( @@ -301,15 +313,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp organization_id: UUID, user_id: UUID, role: OrganizationRole, - custom_permissions: Optional[str] = None - ) -> Optional[UserOrganization]: + custom_permissions: str | None = None, + ) -> UserOrganization | None: """Update a user's role in an organization.""" try: result = await db.execute( select(UserOrganization).where( and_( UserOrganization.user_id == user_id, - UserOrganization.organization_id == organization_id + UserOrganization.organization_id == organization_id, ) ) ) @@ -326,7 +338,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp return user_org except Exception as e: await db.rollback() - logger.error(f"Error updating user role: {str(e)}", exc_info=True) + logger.error(f"Error updating user role: {e!s}", exc_info=True) raise async def get_organization_members( @@ -336,8 +348,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp organization_id: UUID, skip: int = 0, limit: int = 100, - is_active: bool = True - ) -> tuple[List[Dict[str, Any]], int]: + is_active: bool = True, + ) -> tuple[list[dict[str, Any]], int]: """ Get members of an organization with user details. @@ -359,46 +371,55 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp count_query = select(func.count()).select_from( select(UserOrganization) .where(UserOrganization.organization_id == organization_id) - .where(UserOrganization.is_active == is_active if is_active is not None else True) + .where( + UserOrganization.is_active == is_active + if is_active is not None + else True + ) .alias() ) count_result = await db.execute(count_query) total = count_result.scalar_one() # Apply ordering and pagination - query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit) + query = ( + query.order_by(UserOrganization.created_at.desc()) + .offset(skip) + .limit(limit) + ) result = await db.execute(query) results = result.all() members = [] for user_org, user in results: - members.append({ - "user_id": user.id, - "email": user.email, - "first_name": user.first_name, - "last_name": user.last_name, - "role": user_org.role, - "is_active": user_org.is_active, - "joined_at": user_org.created_at - }) + members.append( + { + "user_id": user.id, + "email": user.email, + "first_name": user.first_name, + "last_name": user.last_name, + "role": user_org.role, + "is_active": user_org.is_active, + "joined_at": user_org.created_at, + } + ) return members, total except Exception as e: - logger.error(f"Error getting organization members: {str(e)}") + logger.error(f"Error getting organization members: {e!s}") raise async def get_user_organizations( - self, - db: AsyncSession, - *, - user_id: UUID, - is_active: bool = True - ) -> List[Organization]: + self, db: AsyncSession, *, user_id: UUID, is_active: bool = True + ) -> list[Organization]: """Get all organizations a user belongs to.""" try: query = ( select(Organization) - .join(UserOrganization, Organization.id == UserOrganization.organization_id) + .join( + UserOrganization, + Organization.id == UserOrganization.organization_id, + ) .where(UserOrganization.user_id == user_id) ) @@ -408,16 +429,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp result = await db.execute(query) return list(result.scalars().all()) except Exception as e: - logger.error(f"Error getting user organizations: {str(e)}") + logger.error(f"Error getting user organizations: {e!s}") raise async def get_user_organizations_with_details( - self, - db: AsyncSession, - *, - user_id: UUID, - is_active: bool = True - ) -> List[Dict[str, Any]]: + self, db: AsyncSession, *, user_id: UUID, is_active: bool = True + ) -> list[dict[str, Any]]: """ Get user's organizations with role and member count in SINGLE QUERY. Eliminates N+1 problem by using subquery for member counts. @@ -430,9 +447,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp member_count_subq = ( select( UserOrganization.organization_id, - func.count(UserOrganization.user_id).label('member_count') + func.count(UserOrganization.user_id).label("member_count"), ) - .where(UserOrganization.is_active == True) + .where(UserOrganization.is_active) .group_by(UserOrganization.organization_id) .subquery() ) @@ -442,10 +459,18 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp select( Organization, UserOrganization.role, - func.coalesce(member_count_subq.c.member_count, 0).label('member_count') + func.coalesce(member_count_subq.c.member_count, 0).label( + "member_count" + ), + ) + .join( + UserOrganization, + Organization.id == UserOrganization.organization_id, + ) + .outerjoin( + member_count_subq, + Organization.id == member_count_subq.c.organization_id, ) - .join(UserOrganization, Organization.id == UserOrganization.organization_id) - .outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id) .where(UserOrganization.user_id == user_id) ) @@ -456,25 +481,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp rows = result.all() return [ - { - 'organization': org, - 'role': role, - 'member_count': member_count - } + {"organization": org, "role": role, "member_count": member_count} for org, role, member_count in rows ] except Exception as e: - logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True) + logger.error( + f"Error getting user organizations with details: {e!s}", exc_info=True + ) raise async def get_user_role_in_org( - self, - db: AsyncSession, - *, - user_id: UUID, - organization_id: UUID - ) -> Optional[OrganizationRole]: + self, db: AsyncSession, *, user_id: UUID, organization_id: UUID + ) -> OrganizationRole | None: """Get a user's role in a specific organization.""" try: result = await db.execute( @@ -482,7 +501,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp and_( UserOrganization.user_id == user_id, UserOrganization.organization_id == organization_id, - UserOrganization.is_active == True + UserOrganization.is_active, ) ) ) @@ -490,29 +509,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp return user_org.role if user_org else None except Exception as e: - logger.error(f"Error getting user role in org: {str(e)}") + logger.error(f"Error getting user role in org: {e!s}") raise async def is_user_org_owner( - self, - db: AsyncSession, - *, - user_id: UUID, - organization_id: UUID + self, db: AsyncSession, *, user_id: UUID, organization_id: UUID ) -> bool: """Check if a user is an owner of an organization.""" - role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) + role = await self.get_user_role_in_org( + db, user_id=user_id, organization_id=organization_id + ) return role == OrganizationRole.OWNER async def is_user_org_admin( - self, - db: AsyncSession, - *, - user_id: UUID, - organization_id: UUID + self, db: AsyncSession, *, user_id: UUID, organization_id: UUID ) -> bool: """Check if a user is an owner or admin of an organization.""" - role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) + role = await self.get_user_role_in_org( + db, user_id=user_id, organization_id=organization_id + ) return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] diff --git a/backend/app/crud/session.py b/backend/app/crud/session.py index 71ddce4..528a7ce 100755 --- a/backend/app/crud/session.py +++ b/backend/app/crud/session.py @@ -1,13 +1,13 @@ """ Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. """ + import logging import uuid -from datetime import datetime, timezone, timedelta -from typing import List, Optional +from datetime import UTC, datetime, timedelta from uuid import UUID -from sqlalchemy import and_, select, update, delete, func +from sqlalchemy import and_, delete, func, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): """Async CRUD operations for user sessions.""" - async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: + async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None: """ Get session by refresh token JTI. @@ -38,10 +38,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): ) return result.scalar_one_or_none() except Exception as e: - logger.error(f"Error getting session by JTI {jti}: {str(e)}") + logger.error(f"Error getting session by JTI {jti}: {e!s}") raise - async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: + async def get_active_by_jti( + self, db: AsyncSession, *, jti: str + ) -> UserSession | None: """ Get active session by refresh token JTI. @@ -57,13 +59,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): select(UserSession).where( and_( UserSession.refresh_token_jti == jti, - UserSession.is_active == True + UserSession.is_active, ) ) ) return result.scalar_one_or_none() except Exception as e: - logger.error(f"Error getting active session by JTI {jti}: {str(e)}") + logger.error(f"Error getting active session by JTI {jti}: {e!s}") raise async def get_user_sessions( @@ -72,8 +74,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): *, user_id: str, active_only: bool = True, - with_user: bool = False - ) -> List[UserSession]: + with_user: bool = False, + ) -> list[UserSession]: """ Get all sessions for a user with optional eager loading. @@ -97,20 +99,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): query = query.options(joinedload(UserSession.user)) if active_only: - query = query.where(UserSession.is_active == True) + query = query.where(UserSession.is_active) query = query.order_by(UserSession.last_used_at.desc()) result = await db.execute(query) return list(result.scalars().all()) except Exception as e: - logger.error(f"Error getting sessions for user {user_id}: {str(e)}") + logger.error(f"Error getting sessions for user {user_id}: {e!s}") raise async def create_session( - self, - db: AsyncSession, - *, - obj_in: SessionCreate + self, db: AsyncSession, *, obj_in: SessionCreate ) -> UserSession: """ Create a new user session. @@ -151,10 +150,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): return db_obj except Exception as e: await db.rollback() - logger.error(f"Error creating session: {str(e)}", exc_info=True) - raise ValueError(f"Failed to create session: {str(e)}") + logger.error(f"Error creating session: {e!s}", exc_info=True) + raise ValueError(f"Failed to create session: {e!s}") - async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]: + async def deactivate( + self, db: AsyncSession, *, session_id: str + ) -> UserSession | None: """ Deactivate a session (logout from device). @@ -184,14 +185,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): return session except Exception as e: await db.rollback() - logger.error(f"Error deactivating session {session_id}: {str(e)}") + logger.error(f"Error deactivating session {session_id}: {e!s}") raise async def deactivate_all_user_sessions( - self, - db: AsyncSession, - *, - user_id: str + self, db: AsyncSession, *, user_id: str ) -> int: """ Deactivate all active sessions for a user (logout from all devices). @@ -209,12 +207,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): stmt = ( update(UserSession) - .where( - and_( - UserSession.user_id == user_uuid, - UserSession.is_active == True - ) - ) + .where(and_(UserSession.user_id == user_uuid, UserSession.is_active)) .values(is_active=False) ) @@ -228,14 +221,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): return count except Exception as e: await db.rollback() - logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}") + logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}") raise async def update_last_used( - self, - db: AsyncSession, - *, - session: UserSession + self, db: AsyncSession, *, session: UserSession ) -> UserSession: """ Update the last_used_at timestamp for a session. @@ -248,14 +238,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): Updated UserSession """ try: - session.last_used_at = datetime.now(timezone.utc) + session.last_used_at = datetime.now(UTC) db.add(session) await db.commit() await db.refresh(session) return session except Exception as e: await db.rollback() - logger.error(f"Error updating last_used for session {session.id}: {str(e)}") + logger.error(f"Error updating last_used for session {session.id}: {e!s}") raise async def update_refresh_token( @@ -264,7 +254,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): *, session: UserSession, new_jti: str, - new_expires_at: datetime + new_expires_at: datetime, ) -> UserSession: """ Update session with new refresh token JTI and expiration. @@ -283,14 +273,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): try: session.refresh_token_jti = new_jti session.expires_at = new_expires_at - session.last_used_at = datetime.now(timezone.utc) + session.last_used_at = datetime.now(UTC) db.add(session) await db.commit() await db.refresh(session) return session except Exception as e: await db.rollback() - logger.error(f"Error updating refresh token for session {session.id}: {str(e)}") + logger.error( + f"Error updating refresh token for session {session.id}: {e!s}" + ) raise async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int: @@ -311,15 +303,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): Number of sessions deleted """ try: - cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days) - now = datetime.now(timezone.utc) + cutoff_date = datetime.now(UTC) - timedelta(days=keep_days) + now = datetime.now(UTC) # Use bulk DELETE with WHERE clause - single query stmt = delete(UserSession).where( and_( - UserSession.is_active == False, + not UserSession.is_active, UserSession.expires_at < now, - UserSession.created_at < cutoff_date + UserSession.created_at < cutoff_date, ) ) @@ -334,15 +326,10 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): return count except Exception as e: await db.rollback() - logger.error(f"Error cleaning up expired sessions: {str(e)}") + logger.error(f"Error cleaning up expired sessions: {e!s}") raise - async def cleanup_expired_for_user( - self, - db: AsyncSession, - *, - user_id: str - ) -> int: + async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int: """ Clean up expired and inactive sessions for a specific user. @@ -363,14 +350,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): logger.error(f"Invalid UUID format: {user_id}") raise ValueError(f"Invalid user ID format: {user_id}") - now = datetime.now(timezone.utc) + now = datetime.now(UTC) # Use bulk DELETE with WHERE clause - single query stmt = delete(UserSession).where( and_( UserSession.user_id == uuid_obj, - UserSession.is_active == False, - UserSession.expires_at < now + not UserSession.is_active, + UserSession.expires_at < now, ) ) @@ -388,7 +375,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): except Exception as e: await db.rollback() logger.error( - f"Error cleaning up expired sessions for user {user_id}: {str(e)}" + f"Error cleaning up expired sessions for user {user_id}: {e!s}" ) raise @@ -409,15 +396,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): result = await db.execute( select(func.count(UserSession.id)).where( - and_( - UserSession.user_id == user_uuid, - UserSession.is_active == True - ) + and_(UserSession.user_id == user_uuid, UserSession.is_active) ) ) return result.scalar_one() except Exception as e: - logger.error(f"Error counting sessions for user {user_id}: {str(e)}") + logger.error(f"Error counting sessions for user {user_id}: {e!s}") raise async def get_all_sessions( @@ -427,8 +411,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): skip: int = 0, limit: int = 100, active_only: bool = True, - with_user: bool = True - ) -> tuple[List[UserSession], int]: + with_user: bool = True, + ) -> tuple[list[UserSession], int]: """ Get all sessions across all users with pagination (admin only). @@ -451,18 +435,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): query = query.options(joinedload(UserSession.user)) if active_only: - query = query.where(UserSession.is_active == True) + query = query.where(UserSession.is_active) # Get total count count_query = select(func.count(UserSession.id)) if active_only: - count_query = count_query.where(UserSession.is_active == True) + count_query = count_query.where(UserSession.is_active) count_result = await db.execute(count_query) total = count_result.scalar_one() # Apply pagination and ordering - query = query.order_by(UserSession.last_used_at.desc()).offset(skip).limit(limit) + query = ( + query.order_by(UserSession.last_used_at.desc()) + .offset(skip) + .limit(limit) + ) result = await db.execute(query) sessions = list(result.scalars().all()) @@ -470,7 +458,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): return sessions, total except Exception as e: - logger.error(f"Error getting all sessions: {str(e)}", exc_info=True) + logger.error(f"Error getting all sessions: {e!s}", exc_info=True) raise diff --git a/backend/app/crud/user.py b/backend/app/crud/user.py index 3efe634..d938303 100755 --- a/backend/app/crud/user.py +++ b/backend/app/crud/user.py @@ -1,8 +1,9 @@ # app/crud/user_async.py """Async CRUD operations for User model using SQLAlchemy 2.0 patterns.""" + import logging -from datetime import datetime, timezone -from typing import Optional, Union, Dict, Any, List, Tuple +from datetime import UTC, datetime +from typing import Any from uuid import UUID from sqlalchemy import or_, select, update @@ -20,15 +21,13 @@ logger = logging.getLogger(__name__) class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): """Async CRUD operations for User model.""" - async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]: + async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None: """Get user by email address.""" try: - result = await db.execute( - select(User).where(User.email == email) - ) + result = await db.execute(select(User).where(User.email == email)) return result.scalar_one_or_none() except Exception as e: - logger.error(f"Error getting user by email {email}: {str(e)}") + logger.error(f"Error getting user by email {email}: {e!s}") raise async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User: @@ -42,9 +41,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): password_hash=password_hash, first_name=obj_in.first_name, last_name=obj_in.last_name, - phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None, - is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False, - preferences={} + phone_number=obj_in.phone_number + if hasattr(obj_in, "phone_number") + else None, + is_superuser=obj_in.is_superuser + if hasattr(obj_in, "is_superuser") + else False, + preferences={}, ) db.add(db_obj) await db.commit() @@ -52,7 +55,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): return db_obj except IntegrityError as e: await db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) if "email" in error_msg.lower(): logger.warning(f"Duplicate email attempted: {obj_in.email}") raise ValueError(f"User with email {obj_in.email} already exists") @@ -60,15 +63,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): raise ValueError(f"Database integrity error: {error_msg}") except Exception as e: await db.rollback() - logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True) + logger.error(f"Unexpected error creating user: {e!s}", exc_info=True) raise async def update( - self, - db: AsyncSession, - *, - db_obj: User, - obj_in: Union[UserUpdate, Dict[str, Any]] + self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any] ) -> User: """Update user with async password hashing if password is updated.""" if isinstance(obj_in, dict): @@ -79,7 +78,9 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): # Handle password separately if it exists in update data # Hash password asynchronously to avoid blocking event loop if "password" in update_data: - update_data["password_hash"] = await get_password_hash_async(update_data["password"]) + update_data["password_hash"] = await get_password_hash_async( + update_data["password"] + ) del update_data["password"] return await super().update(db, db_obj=db_obj, obj_in=update_data) @@ -90,11 +91,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): *, skip: int = 0, limit: int = 100, - sort_by: Optional[str] = None, + sort_by: str | None = None, sort_order: str = "asc", - filters: Optional[Dict[str, Any]] = None, - search: Optional[str] = None - ) -> Tuple[List[User], int]: + filters: dict[str, Any] | None = None, + search: str | None = None, + ) -> tuple[list[User], int]: """ Get multiple users with total count, filtering, sorting, and search. @@ -136,12 +137,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): search_filter = or_( User.email.ilike(f"%{search}%"), User.first_name.ilike(f"%{search}%"), - User.last_name.ilike(f"%{search}%") + User.last_name.ilike(f"%{search}%"), ) query = query.where(search_filter) # Get total count from sqlalchemy import func + count_query = select(func.count()).select_from(query.alias()) count_result = await db.execute(count_query) total = count_result.scalar_one() @@ -162,15 +164,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): return users, total except Exception as e: - logger.error(f"Error retrieving paginated users: {str(e)}") + logger.error(f"Error retrieving paginated users: {e!s}") raise async def bulk_update_status( - self, - db: AsyncSession, - *, - user_ids: List[UUID], - is_active: bool + self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool ) -> int: """ Bulk update is_active status for multiple users. @@ -192,7 +190,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): update(User) .where(User.id.in_(user_ids)) .where(User.deleted_at.is_(None)) # Don't update deleted users - .values(is_active=is_active, updated_at=datetime.now(timezone.utc)) + .values(is_active=is_active, updated_at=datetime.now(UTC)) ) result = await db.execute(stmt) @@ -204,15 +202,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): except Exception as e: await db.rollback() - logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True) + logger.error(f"Error bulk updating user status: {e!s}", exc_info=True) raise async def bulk_soft_delete( self, db: AsyncSession, *, - user_ids: List[UUID], - exclude_user_id: Optional[UUID] = None + user_ids: list[UUID], + exclude_user_id: UUID | None = None, ) -> int: """ Bulk soft delete multiple users. @@ -239,11 +237,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): stmt = ( update(User) .where(User.id.in_(filtered_ids)) - .where(User.deleted_at.is_(None)) # Don't re-delete already deleted users + .where( + User.deleted_at.is_(None) + ) # Don't re-delete already deleted users .values( - deleted_at=datetime.now(timezone.utc), + deleted_at=datetime.now(UTC), is_active=False, - updated_at=datetime.now(timezone.utc) + updated_at=datetime.now(UTC), ) ) @@ -256,7 +256,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): except Exception as e: await db.rollback() - logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True) + logger.error(f"Error bulk deleting users: {e!s}", exc_info=True) raise def is_active(self, user: User) -> bool: diff --git a/backend/app/init_db.py b/backend/app/init_db.py index eaafafe..ba0b61a 100644 --- a/backend/app/init_db.py +++ b/backend/app/init_db.py @@ -4,9 +4,9 @@ Async database initialization script. Creates the first superuser if configured and doesn't already exist. """ + import asyncio import logging -from typing import Optional from app.core.config import settings from app.core.database import SessionLocal, engine @@ -17,7 +17,7 @@ from app.schemas.users import UserCreate logger = logging.getLogger(__name__) -async def init_db() -> Optional[User]: +async def init_db() -> User | None: """ Initialize database with first superuser if settings are configured and user doesn't exist. @@ -49,7 +49,7 @@ async def init_db() -> Optional[User]: password=superuser_password, first_name="Admin", last_name="User", - is_superuser=True + is_superuser=True, ) user = await user_crud.create(session, obj_in=user_in) @@ -70,13 +70,13 @@ async def main(): # Configure logging to show info logs logging.basicConfig( level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) try: user = await init_db() if user: - print(f"✓ Database initialized successfully") + print("✓ Database initialized successfully") print(f"✓ Superuser: {user.email}") else: print("✗ Failed to initialize database") diff --git a/backend/app/main.py b/backend/app/main.py index 2ebee9a..2b5d16d 100755 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -2,10 +2,10 @@ import logging import os from contextlib import asynccontextmanager from datetime import datetime -from typing import Dict, Any +from typing import Any from apscheduler.schedulers.asyncio import AsyncIOScheduler -from fastapi import FastAPI, status, Request, HTTPException +from fastapi import FastAPI, HTTPException, Request, status from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse, JSONResponse @@ -19,9 +19,9 @@ from app.core.database import check_database_health from app.core.exceptions import ( APIException, api_exception_handler, - validation_exception_handler, http_exception_handler, - unhandled_exception_handler + unhandled_exception_handler, + validation_exception_handler, ) scheduler = AsyncIOScheduler() @@ -52,11 +52,11 @@ async def lifespan(app: FastAPI): # Runs daily at 2:00 AM server time scheduler.add_job( cleanup_expired_sessions, - 'cron', + "cron", hour=2, minute=0, - id='cleanup_expired_sessions', - replace_existing=True + id="cleanup_expired_sessions", + replace_existing=True, ) scheduler.start() @@ -73,12 +73,12 @@ async def lifespan(app: FastAPI): logger.info("Scheduled jobs stopped") -logger.info(f"Starting app!!!") +logger.info("Starting app!!!") app = FastAPI( title=settings.PROJECT_NAME, version=settings.VERSION, openapi_url=f"{settings.API_V1_STR}/openapi.json", - lifespan=lifespan + lifespan=lifespan, ) # Add rate limiter state to app @@ -96,7 +96,14 @@ app.add_middleware( CORSMiddleware, allow_origins=settings.BACKEND_CORS_ORIGINS, allow_credentials=True, - allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only + allow_methods=[ + "GET", + "POST", + "PUT", + "PATCH", + "DELETE", + "OPTIONS", + ], # Explicit methods only allow_headers=[ "Content-Type", "Authorization", @@ -129,12 +136,14 @@ async def limit_request_size(request: Request, call_next): status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, content={ "success": False, - "errors": [{ - "code": "REQUEST_TOO_LARGE", - "message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB", - "field": None - }] - } + "errors": [ + { + "code": "REQUEST_TOO_LARGE", + "message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB", + "field": None, + } + ], + }, ) response = await call_next(request) @@ -165,15 +174,19 @@ async def add_security_headers(request: Request, call_next): # Enforce HTTPS in production if settings.ENVIRONMENT == "production": - response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" + response.headers["Strict-Transport-Security"] = ( + "max-age=31536000; includeSubDomains" + ) # Content Security Policy csp_mode = settings.CSP_MODE.lower() # Special handling for API docs - is_docs = request.url.path in ["/docs", "/redoc"] or \ - request.url.path.startswith("/docs/") or \ - request.url.path.startswith("/redoc/") + is_docs = ( + request.url.path in ["/docs", "/redoc"] + or request.url.path.startswith("/docs/") + or request.url.path.startswith("/redoc/") + ) if csp_mode == "disabled": # No CSP (only for local development/debugging) @@ -264,7 +277,7 @@ async def root(): description="Check the health status of the API and its dependencies", response_description="Health status information", tags=["Health"], - operation_id="health_check" + operation_id="health_check", ) async def health_check() -> JSONResponse: """ @@ -278,12 +291,12 @@ async def health_check() -> JSONResponse: - environment: Current environment (development, staging, production) - database: Database connectivity status """ - health_status: Dict[str, Any] = { + health_status: dict[str, Any] = { "status": "healthy", "timestamp": datetime.utcnow().isoformat() + "Z", "version": settings.VERSION, "environment": settings.ENVIRONMENT, - "checks": {} + "checks": {}, } response_status = status.HTTP_200_OK @@ -294,7 +307,7 @@ async def health_check() -> JSONResponse: if db_healthy: health_status["checks"]["database"] = { "status": "healthy", - "message": "Database connection successful" + "message": "Database connection successful", } else: raise Exception("Database health check returned unhealthy status") @@ -302,15 +315,12 @@ async def health_check() -> JSONResponse: health_status["status"] = "unhealthy" health_status["checks"]["database"] = { "status": "unhealthy", - "message": f"Database connection failed: {str(e)}" + "message": f"Database connection failed: {e!s}", } response_status = status.HTTP_503_SERVICE_UNAVAILABLE logger.error(f"Health check failed - database error: {e}") - return JSONResponse( - status_code=response_status, - content=health_status - ) + return JSONResponse(status_code=response_status, content=health_status) app.include_router(api_router, prefix=settings.API_V1_STR) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 581caf6..5f476b4 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -2,17 +2,25 @@ Models package initialization. Imports all models to ensure they're registered with SQLAlchemy. """ + # First import Base to avoid circular imports from app.core.database import Base + from .base import TimestampMixin, UUIDMixin from .organization import Organization + # Import models from .user import User -from .user_organization import UserOrganization, OrganizationRole +from .user_organization import OrganizationRole, UserOrganization from .user_session import UserSession __all__ = [ - 'Base', 'TimestampMixin', 'UUIDMixin', - 'User', 'UserSession', - 'Organization', 'UserOrganization', 'OrganizationRole', -] \ No newline at end of file + "Base", + "Organization", + "OrganizationRole", + "TimestampMixin", + "UUIDMixin", + "User", + "UserOrganization", + "UserSession", +] diff --git a/backend/app/models/base.py b/backend/app/models/base.py index 5a6f55e..7a8f889 100644 --- a/backend/app/models/base.py +++ b/backend/app/models/base.py @@ -1,20 +1,27 @@ import uuid -from datetime import datetime, timezone +from datetime import UTC, datetime from sqlalchemy import Column, DateTime from sqlalchemy.dialects.postgresql import UUID # noinspection PyUnresolvedReferences -from app.core.database import Base class TimestampMixin: """Mixin to add created_at and updated_at timestamps to models""" - created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False) - updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), - onupdate=lambda: datetime.now(timezone.utc), nullable=False) + + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False + ) + updated_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + nullable=False, + ) class UUIDMixin: """Mixin to add UUID primary keys to models""" + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index b81c7cb..5a3d2b2 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -1,5 +1,5 @@ # app/models/organization.py -from sqlalchemy import Column, String, Boolean, Text, Index +from sqlalchemy import Boolean, Column, Index, String, Text from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import relationship @@ -11,7 +11,8 @@ class Organization(Base, UUIDMixin, TimestampMixin): Organization model for multi-tenant support. Users can belong to multiple organizations with different roles. """ - __tablename__ = 'organizations' + + __tablename__ = "organizations" name = Column(String(255), nullable=False, index=True) slug = Column(String(255), unique=True, nullable=False, index=True) @@ -20,11 +21,13 @@ class Organization(Base, UUIDMixin, TimestampMixin): settings = Column(JSONB, default={}) # Relationships - user_organizations = relationship("UserOrganization", back_populates="organization", cascade="all, delete-orphan") + user_organizations = relationship( + "UserOrganization", back_populates="organization", cascade="all, delete-orphan" + ) __table_args__ = ( - Index('ix_organizations_name_active', 'name', 'is_active'), - Index('ix_organizations_slug_active', 'slug', 'is_active'), + Index("ix_organizations_name_active", "name", "is_active"), + Index("ix_organizations_slug_active", "slug", "is_active"), ) def __repr__(self): diff --git a/backend/app/models/user.py b/backend/app/models/user.py index d5eb715..c6604e5 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, String, Boolean, DateTime +from sqlalchemy import Boolean, Column, DateTime, String from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import relationship @@ -6,7 +6,7 @@ from .base import Base, TimestampMixin, UUIDMixin class User(Base, UUIDMixin, TimestampMixin): - __tablename__ = 'users' + __tablename__ = "users" email = Column(String(255), unique=True, nullable=False, index=True) password_hash = Column(String(255), nullable=False) @@ -19,7 +19,9 @@ class User(Base, UUIDMixin, TimestampMixin): deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) # Relationships - user_organizations = relationship("UserOrganization", back_populates="user", cascade="all, delete-orphan") + user_organizations = relationship( + "UserOrganization", back_populates="user", cascade="all, delete-orphan" + ) def __repr__(self): - return f"" \ No newline at end of file + return f"" diff --git a/backend/app/models/user_organization.py b/backend/app/models/user_organization.py index 1102439..178ddaa 100644 --- a/backend/app/models/user_organization.py +++ b/backend/app/models/user_organization.py @@ -1,7 +1,7 @@ # app/models/user_organization.py from enum import Enum as PyEnum -from sqlalchemy import Column, ForeignKey, Boolean, String, Index, Enum +from sqlalchemy import Boolean, Column, Enum, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID as PGUUID from sqlalchemy.orm import relationship @@ -14,6 +14,7 @@ class OrganizationRole(str, PyEnum): These provide a baseline role system that can be optionally used. Projects can extend this or implement their own permission system. """ + OWNER = "owner" # Full control over organization ADMIN = "admin" # Can manage users and settings MEMBER = "member" # Regular member with standard access @@ -25,25 +26,41 @@ class UserOrganization(Base, TimestampMixin): Junction table for many-to-many relationship between Users and Organizations. Includes role information for flexible RBAC. """ - __tablename__ = 'user_organizations' - user_id = Column(PGUUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), primary_key=True) - organization_id = Column(PGUUID(as_uuid=True), ForeignKey('organizations.id', ondelete='CASCADE'), primary_key=True) + __tablename__ = "user_organizations" - role = Column(Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False, index=True) + user_id = Column( + PGUUID(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + primary_key=True, + ) + organization_id = Column( + PGUUID(as_uuid=True), + ForeignKey("organizations.id", ondelete="CASCADE"), + primary_key=True, + ) + + role = Column( + Enum(OrganizationRole), + default=OrganizationRole.MEMBER, + nullable=False, + index=True, + ) is_active = Column(Boolean, default=True, nullable=False, index=True) # Optional: Custom permissions override for specific users - custom_permissions = Column(String(500), nullable=True) # JSON array of permission strings + custom_permissions = Column( + String(500), nullable=True + ) # JSON array of permission strings # Relationships user = relationship("User", back_populates="user_organizations") organization = relationship("Organization", back_populates="user_organizations") __table_args__ = ( - Index('ix_user_org_user_active', 'user_id', 'is_active'), - Index('ix_user_org_org_active', 'organization_id', 'is_active'), - Index('ix_user_org_role', 'role'), + Index("ix_user_org_user_active", "user_id", "is_active"), + Index("ix_user_org_org_active", "organization_id", "is_active"), + Index("ix_user_org_role", "role"), ) def __repr__(self): diff --git a/backend/app/models/user_session.py b/backend/app/models/user_session.py index 781e07f..131abe0 100644 --- a/backend/app/models/user_session.py +++ b/backend/app/models/user_session.py @@ -6,7 +6,10 @@ This allows users to: - Logout from specific devices - Manage their active sessions """ -from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Index + +from datetime import UTC + +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship @@ -20,19 +23,27 @@ class UserSession(Base, UUIDMixin, TimestampMixin): Each time a user logs in from a device, a new session is created. Sessions are identified by the refresh token JTI (JWT ID). """ - __tablename__ = 'user_sessions' + + __tablename__ = "user_sessions" # Foreign key to user - user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False, index=True) + user_id = Column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) # Refresh token identifier (JWT ID from the refresh token) refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True) # Device information device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook" - device_id = Column(String(255), nullable=True) # Persistent device identifier (from client) - ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars) - user_agent = Column(String(500), nullable=True) # Browser/app user agent + device_id = Column( + String(255), nullable=True + ) # Persistent device identifier (from client) + ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars) + user_agent = Column(String(500), nullable=True) # Browser/app user agent # Session timing last_used_at = Column(DateTime(timezone=True), nullable=False) @@ -50,8 +61,8 @@ class UserSession(Base, UUIDMixin, TimestampMixin): # Composite indexes for performance (defined in migration) __table_args__ = ( - Index('ix_user_sessions_user_active', 'user_id', 'is_active'), - Index('ix_user_sessions_jti_active', 'refresh_token_jti', 'is_active'), + Index("ix_user_sessions_user_active", "user_id", "is_active"), + Index("ix_user_sessions_jti_active", "refresh_token_jti", "is_active"), ) def __repr__(self): @@ -60,21 +71,24 @@ class UserSession(Base, UUIDMixin, TimestampMixin): @property def is_expired(self) -> bool: """Check if session has expired.""" - from datetime import datetime, timezone - return self.expires_at < datetime.now(timezone.utc) + from datetime import datetime + + return self.expires_at < datetime.now(UTC) def to_dict(self): """Convert session to dictionary for serialization.""" return { - 'id': str(self.id), - 'user_id': str(self.user_id), - 'device_name': self.device_name, - 'device_id': self.device_id, - 'ip_address': self.ip_address, - 'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None, - 'expires_at': self.expires_at.isoformat() if self.expires_at else None, - 'is_active': self.is_active, - 'location_city': self.location_city, - 'location_country': self.location_country, - 'created_at': self.created_at.isoformat() if self.created_at else None, + "id": str(self.id), + "user_id": str(self.user_id), + "device_name": self.device_name, + "device_id": self.device_id, + "ip_address": self.ip_address, + "last_used_at": self.last_used_at.isoformat() + if self.last_used_at + else None, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "is_active": self.is_active, + "location_city": self.location_city, + "location_country": self.location_country, + "created_at": self.created_at.isoformat() if self.created_at else None, } diff --git a/backend/app/schemas/common.py b/backend/app/schemas/common.py index a5f7a5d..c5f437d 100644 --- a/backend/app/schemas/common.py +++ b/backend/app/schemas/common.py @@ -1,18 +1,20 @@ """ Common schemas used across the API for pagination, responses, filtering, and sorting. """ + from enum import Enum from math import ceil -from typing import Generic, TypeVar, List, Optional +from typing import TypeVar from uuid import UUID from pydantic import BaseModel, Field -T = TypeVar('T') +T = TypeVar("T") class SortOrder(str, Enum): """Sort order options.""" + ASC = "asc" DESC = "desc" @@ -20,16 +22,9 @@ class SortOrder(str, Enum): class PaginationParams(BaseModel): """Parameters for pagination.""" - page: int = Field( - default=1, - ge=1, - description="Page number (1-indexed)" - ) + page: int = Field(default=1, ge=1, description="Page number (1-indexed)") limit: int = Field( - default=20, - ge=1, - le=100, - description="Number of items per page (max 100)" + default=20, ge=1, le=100, description="Number of items per page (max 100)" ) @property @@ -42,34 +37,20 @@ class PaginationParams(BaseModel): """Alias for offset (compatibility with existing code).""" return self.offset - model_config = { - "json_schema_extra": { - "example": { - "page": 1, - "limit": 20 - } - } - } + model_config = {"json_schema_extra": {"example": {"page": 1, "limit": 20}}} class SortParams(BaseModel): """Parameters for sorting.""" - sort_by: Optional[str] = Field( - default=None, - description="Field name to sort by" - ) + sort_by: str | None = Field(default=None, description="Field name to sort by") sort_order: SortOrder = Field( - default=SortOrder.ASC, - description="Sort order (asc or desc)" + default=SortOrder.ASC, description="Sort order (asc or desc)" ) model_config = { "json_schema_extra": { - "example": { - "sort_by": "created_at", - "sort_order": "desc" - } + "example": {"sort_by": "created_at", "sort_order": "desc"} } } @@ -92,32 +73,30 @@ class PaginationMeta(BaseModel): "page_size": 20, "total_pages": 8, "has_next": True, - "has_prev": False + "has_prev": False, } } } -class PaginatedResponse(BaseModel, Generic[T]): +class PaginatedResponse[T](BaseModel): """Generic paginated response wrapper.""" - data: List[T] = Field(..., description="List of items") + data: list[T] = Field(..., description="List of items") pagination: PaginationMeta = Field(..., description="Pagination metadata") model_config = { "json_schema_extra": { "example": { - "data": [ - {"id": "123", "name": "Example Item"} - ], + "data": [{"id": "123", "name": "Example Item"}], "pagination": { "total": 150, "page": 1, "page_size": 20, "total_pages": 8, "has_next": True, - "has_prev": False - } + "has_prev": False, + }, } } } @@ -131,10 +110,7 @@ class MessageResponse(BaseModel): model_config = { "json_schema_extra": { - "example": { - "success": True, - "message": "Operation completed successfully" - } + "example": {"success": True, "message": "Operation completed successfully"} } } @@ -142,11 +118,11 @@ class MessageResponse(BaseModel): class BulkActionRequest(BaseModel): """Request schema for bulk operations on multiple items.""" - ids: List[UUID] = Field( + ids: list[UUID] = Field( ..., min_length=1, max_length=100, - description="List of item IDs to perform action on (max 100)" + description="List of item IDs to perform action on (max 100)", ) model_config = { @@ -154,7 +130,7 @@ class BulkActionRequest(BaseModel): "example": { "ids": [ "550e8400-e29b-41d4-a716-446655440000", - "6ba7b810-9dad-11d1-80b4-00c04fd430c8" + "6ba7b810-9dad-11d1-80b4-00c04fd430c8", ] } } @@ -166,24 +142,23 @@ class BulkActionResponse(BaseModel): success: bool = Field(default=True, description="Operation success status") message: str = Field(..., description="Human-readable message") - affected_count: int = Field(..., description="Number of items affected by the operation") + affected_count: int = Field( + ..., description="Number of items affected by the operation" + ) model_config = { "json_schema_extra": { "example": { "success": True, "message": "Successfully deactivated 5 users", - "affected_count": 5 + "affected_count": 5, } } } def create_pagination_meta( - total: int, - page: int, - limit: int, - items_count: int + total: int, page: int, limit: int, items_count: int ) -> PaginationMeta: """ Helper function to create pagination metadata. @@ -205,5 +180,5 @@ def create_pagination_meta( page_size=items_count, total_pages=total_pages, has_next=page < total_pages, - has_prev=page > 1 + has_prev=page > 1, ) diff --git a/backend/app/schemas/errors.py b/backend/app/schemas/errors.py index 99c9a95..57d2973 100644 --- a/backend/app/schemas/errors.py +++ b/backend/app/schemas/errors.py @@ -1,8 +1,8 @@ """ Error schemas for standardized API error responses. """ + from enum import Enum -from typing import List, Optional from pydantic import BaseModel, Field @@ -53,14 +53,14 @@ class ErrorDetail(BaseModel): code: ErrorCode = Field(..., description="Machine-readable error code") message: str = Field(..., description="Human-readable error message") - field: Optional[str] = Field(None, description="Field name if error is field-specific") + field: str | None = Field(None, description="Field name if error is field-specific") model_config = { "json_schema_extra": { "example": { "code": "VAL_002", "message": "Password must be at least 8 characters long", - "field": "password" + "field": "password", } } } @@ -70,7 +70,7 @@ class ErrorResponse(BaseModel): """Standardized error response format.""" success: bool = Field(default=False, description="Always false for error responses") - errors: List[ErrorDetail] = Field(..., description="List of errors that occurred") + errors: list[ErrorDetail] = Field(..., description="List of errors that occurred") model_config = { "json_schema_extra": { @@ -80,9 +80,9 @@ class ErrorResponse(BaseModel): { "code": "AUTH_001", "message": "Invalid email or password", - "field": None + "field": None, } - ] + ], } } } diff --git a/backend/app/schemas/organizations.py b/backend/app/schemas/organizations.py index 2714b0c..b71e91e 100644 --- a/backend/app/schemas/organizations.py +++ b/backend/app/schemas/organizations.py @@ -1,10 +1,10 @@ # app/schemas/organizations.py import re from datetime import datetime -from typing import Optional, Dict, Any, List +from typing import Any from uuid import UUID -from pydantic import BaseModel, field_validator, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from app.models.user_organization import OrganizationRole @@ -12,85 +12,94 @@ from app.models.user_organization import OrganizationRole # Organization Schemas class OrganizationBase(BaseModel): """Base organization schema with common fields.""" - name: str = Field(..., min_length=1, max_length=255) - slug: Optional[str] = Field(None, min_length=1, max_length=255) - description: Optional[str] = None - is_active: bool = True - settings: Optional[Dict[str, Any]] = {} - @field_validator('slug') + name: str = Field(..., min_length=1, max_length=255) + slug: str | None = Field(None, min_length=1, max_length=255) + description: str | None = None + is_active: bool = True + settings: dict[str, Any] | None = {} + + @field_validator("slug") @classmethod - def validate_slug(cls, v: Optional[str]) -> Optional[str]: + def validate_slug(cls, v: str | None) -> str | None: """Validate slug format: lowercase, alphanumeric, hyphens only.""" if v is None: return v - if not re.match(r'^[a-z0-9-]+$', v): - raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens') - if v.startswith('-') or v.endswith('-'): - raise ValueError('Slug cannot start or end with a hyphen') - if '--' in v: - raise ValueError('Slug cannot contain consecutive hyphens') + if not re.match(r"^[a-z0-9-]+$", v): + raise ValueError( + "Slug must contain only lowercase letters, numbers, and hyphens" + ) + if v.startswith("-") or v.endswith("-"): + raise ValueError("Slug cannot start or end with a hyphen") + if "--" in v: + raise ValueError("Slug cannot contain consecutive hyphens") return v - @field_validator('name') + @field_validator("name") @classmethod def validate_name(cls, v: str) -> str: """Validate organization name.""" if not v or v.strip() == "": - raise ValueError('Organization name cannot be empty') + raise ValueError("Organization name cannot be empty") return v.strip() class OrganizationCreate(OrganizationBase): """Schema for creating a new organization.""" + name: str = Field(..., min_length=1, max_length=255) slug: str = Field(..., min_length=1, max_length=255) class OrganizationUpdate(BaseModel): """Schema for updating an organization.""" - name: Optional[str] = Field(None, min_length=1, max_length=255) - slug: Optional[str] = Field(None, min_length=1, max_length=255) - description: Optional[str] = None - is_active: Optional[bool] = None - settings: Optional[Dict[str, Any]] = None - @field_validator('slug') + name: str | None = Field(None, min_length=1, max_length=255) + slug: str | None = Field(None, min_length=1, max_length=255) + description: str | None = None + is_active: bool | None = None + settings: dict[str, Any] | None = None + + @field_validator("slug") @classmethod - def validate_slug(cls, v: Optional[str]) -> Optional[str]: + def validate_slug(cls, v: str | None) -> str | None: """Validate slug format.""" if v is None: return v - if not re.match(r'^[a-z0-9-]+$', v): - raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens') - if v.startswith('-') or v.endswith('-'): - raise ValueError('Slug cannot start or end with a hyphen') - if '--' in v: - raise ValueError('Slug cannot contain consecutive hyphens') + if not re.match(r"^[a-z0-9-]+$", v): + raise ValueError( + "Slug must contain only lowercase letters, numbers, and hyphens" + ) + if v.startswith("-") or v.endswith("-"): + raise ValueError("Slug cannot start or end with a hyphen") + if "--" in v: + raise ValueError("Slug cannot contain consecutive hyphens") return v - @field_validator('name') + @field_validator("name") @classmethod - def validate_name(cls, v: Optional[str]) -> Optional[str]: + def validate_name(cls, v: str | None) -> str | None: """Validate organization name.""" if v is not None and (not v or v.strip() == ""): - raise ValueError('Organization name cannot be empty') + raise ValueError("Organization name cannot be empty") return v.strip() if v else v class OrganizationResponse(OrganizationBase): """Schema for organization API responses.""" + id: UUID created_at: datetime - updated_at: Optional[datetime] = None - member_count: Optional[int] = 0 + updated_at: datetime | None = None + member_count: int | None = 0 model_config = ConfigDict(from_attributes=True) class OrganizationListResponse(BaseModel): """Schema for paginated organization list responses.""" - organizations: List[OrganizationResponse] + + organizations: list[OrganizationResponse] total: int page: int page_size: int @@ -100,44 +109,49 @@ class OrganizationListResponse(BaseModel): # User-Organization Relationship Schemas class UserOrganizationBase(BaseModel): """Base schema for user-organization relationship.""" + role: OrganizationRole = OrganizationRole.MEMBER is_active: bool = True - custom_permissions: Optional[str] = None + custom_permissions: str | None = None class UserOrganizationCreate(BaseModel): """Schema for adding a user to an organization.""" + user_id: UUID role: OrganizationRole = OrganizationRole.MEMBER - custom_permissions: Optional[str] = None + custom_permissions: str | None = None class UserOrganizationUpdate(BaseModel): """Schema for updating user's role in an organization.""" - role: Optional[OrganizationRole] = None - is_active: Optional[bool] = None - custom_permissions: Optional[str] = None + + role: OrganizationRole | None = None + is_active: bool | None = None + custom_permissions: str | None = None class UserOrganizationResponse(BaseModel): """Schema for user-organization relationship responses.""" + user_id: UUID organization_id: UUID role: OrganizationRole is_active: bool - custom_permissions: Optional[str] = None + custom_permissions: str | None = None created_at: datetime - updated_at: Optional[datetime] = None + updated_at: datetime | None = None model_config = ConfigDict(from_attributes=True) class OrganizationMemberResponse(BaseModel): """Schema for organization member information.""" + user_id: UUID email: str first_name: str - last_name: Optional[str] = None + last_name: str | None = None role: OrganizationRole is_active: bool joined_at: datetime @@ -147,7 +161,8 @@ class OrganizationMemberResponse(BaseModel): class OrganizationMemberListResponse(BaseModel): """Schema for paginated organization member list.""" - members: List[OrganizationMemberResponse] + + members: list[OrganizationMemberResponse] total: int page: int page_size: int diff --git a/backend/app/schemas/sessions.py b/backend/app/schemas/sessions.py index 8e49ae0..9c1b7a2 100644 --- a/backend/app/schemas/sessions.py +++ b/backend/app/schemas/sessions.py @@ -1,37 +1,44 @@ """ Pydantic schemas for user session management. """ + from datetime import datetime -from typing import Optional from uuid import UUID -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field class SessionBase(BaseModel): """Base schema for user sessions.""" - device_name: Optional[str] = Field(None, max_length=255, description="Friendly device name") - device_id: Optional[str] = Field(None, max_length=255, description="Persistent device identifier") + + device_name: str | None = Field( + None, max_length=255, description="Friendly device name" + ) + device_id: str | None = Field( + None, max_length=255, description="Persistent device identifier" + ) class SessionCreate(SessionBase): """Schema for creating a new session (internal use).""" + user_id: UUID refresh_token_jti: str = Field(..., max_length=255) - ip_address: Optional[str] = Field(None, max_length=45) - user_agent: Optional[str] = Field(None, max_length=500) + ip_address: str | None = Field(None, max_length=45) + user_agent: str | None = Field(None, max_length=500) last_used_at: datetime expires_at: datetime - location_city: Optional[str] = Field(None, max_length=100) - location_country: Optional[str] = Field(None, max_length=100) + location_city: str | None = Field(None, max_length=100) + location_country: str | None = Field(None, max_length=100) class SessionUpdate(BaseModel): """Schema for updating a session (internal use).""" - last_used_at: Optional[datetime] = None - is_active: Optional[bool] = None - refresh_token_jti: Optional[str] = None - expires_at: Optional[datetime] = None + + last_used_at: datetime | None = None + is_active: bool | None = None + refresh_token_jti: str | None = None + expires_at: datetime | None = None class SessionResponse(SessionBase): @@ -40,14 +47,17 @@ class SessionResponse(SessionBase): This is what users see when they list their active sessions. """ + id: UUID - ip_address: Optional[str] = None - location_city: Optional[str] = None - location_country: Optional[str] = None + ip_address: str | None = None + location_city: str | None = None + location_country: str | None = None last_used_at: datetime created_at: datetime expires_at: datetime - is_current: bool = Field(default=False, description="Whether this is the current session") + is_current: bool = Field( + default=False, description="Whether this is the current session" + ) model_config = ConfigDict( from_attributes=True, @@ -62,14 +72,15 @@ class SessionResponse(SessionBase): "last_used_at": "2025-10-31T12:00:00Z", "created_at": "2025-10-30T09:00:00Z", "expires_at": "2025-11-06T09:00:00Z", - "is_current": True + "is_current": True, } - } + }, ) class SessionListResponse(BaseModel): """Response containing list of sessions.""" + sessions: list[SessionResponse] total: int = Field(..., description="Total number of active sessions") @@ -84,10 +95,10 @@ class SessionListResponse(BaseModel): "last_used_at": "2025-10-31T12:00:00Z", "created_at": "2025-10-30T09:00:00Z", "expires_at": "2025-11-06T09:00:00Z", - "is_current": True + "is_current": True, } ], - "total": 1 + "total": 1, } } ) @@ -95,17 +106,14 @@ class SessionListResponse(BaseModel): class LogoutRequest(BaseModel): """Request schema for logout endpoint.""" + refresh_token: str = Field( - ..., - description="Refresh token for the session to logout from", - min_length=10 + ..., description="Refresh token for the session to logout from", min_length=10 ) model_config = ConfigDict( json_schema_extra={ - "example": { - "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." - } + "example": {"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."} } ) @@ -116,13 +124,14 @@ class AdminSessionResponse(SessionBase): Includes user information for admin to see who owns each session. """ + id: UUID user_id: UUID user_email: str = Field(..., description="Email of the user who owns this session") - user_full_name: Optional[str] = Field(None, description="Full name of the user") - ip_address: Optional[str] = None - location_city: Optional[str] = None - location_country: Optional[str] = None + user_full_name: str | None = Field(None, description="Full name of the user") + ip_address: str | None = None + location_city: str | None = None + location_country: str | None = None last_used_at: datetime created_at: datetime expires_at: datetime @@ -144,20 +153,21 @@ class AdminSessionResponse(SessionBase): "last_used_at": "2025-10-31T12:00:00Z", "created_at": "2025-10-30T09:00:00Z", "expires_at": "2025-11-06T09:00:00Z", - "is_active": True + "is_active": True, } - } + }, ) class DeviceInfo(BaseModel): """Device information extracted from request.""" - device_name: Optional[str] = None - device_id: Optional[str] = None - ip_address: Optional[str] = None - user_agent: Optional[str] = None - location_city: Optional[str] = None - location_country: Optional[str] = None + + device_name: str | None = None + device_id: str | None = None + ip_address: str | None = None + user_agent: str | None = None + location_city: str | None = None + location_country: str | None = None model_config = ConfigDict( json_schema_extra={ @@ -167,7 +177,7 @@ class DeviceInfo(BaseModel): "ip_address": "192.168.1.50", "user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...", "location_city": "San Francisco", - "location_country": "United States" + "location_country": "United States", } } ) diff --git a/backend/app/schemas/users.py b/backend/app/schemas/users.py index 3f2c265..ee5994e 100755 --- a/backend/app/schemas/users.py +++ b/backend/app/schemas/users.py @@ -1,9 +1,9 @@ # app/schemas/users.py from datetime import datetime -from typing import Optional, Dict, Any +from typing import Any from uuid import UUID -from pydantic import BaseModel, EmailStr, field_validator, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator from app.schemas.validators import validate_password_strength, validate_phone_number @@ -11,12 +11,12 @@ from app.schemas.validators import validate_password_strength, validate_phone_nu class UserBase(BaseModel): email: EmailStr first_name: str - last_name: Optional[str] = None - phone_number: Optional[str] = None + last_name: str | None = None + phone_number: str | None = None - @field_validator('phone_number') + @field_validator("phone_number") @classmethod - def validate_phone(cls, v: Optional[str]) -> Optional[str]: + def validate_phone(cls, v: str | None) -> str | None: return validate_phone_number(v) @@ -24,7 +24,7 @@ class UserCreate(UserBase): password: str is_superuser: bool = False - @field_validator('password') + @field_validator("password") @classmethod def password_strength(cls, v: str) -> str: """Enterprise-grade password strength validation""" @@ -32,30 +32,32 @@ class UserCreate(UserBase): class UserUpdate(BaseModel): - first_name: Optional[str] = None - last_name: Optional[str] = None - phone_number: Optional[str] = None - password: Optional[str] = None - preferences: Optional[Dict[str, Any]] = None - is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates - is_superuser: Optional[bool] = None # Explicitly reject privilege escalation attempts + first_name: str | None = None + last_name: str | None = None + phone_number: str | None = None + password: str | None = None + preferences: dict[str, Any] | None = None + is_active: bool | None = ( + None # Changed default from True to None to avoid unintended updates + ) + is_superuser: bool | None = None # Explicitly reject privilege escalation attempts - @field_validator('phone_number') + @field_validator("phone_number") @classmethod - def validate_phone(cls, v: Optional[str]) -> Optional[str]: + def validate_phone(cls, v: str | None) -> str | None: return validate_phone_number(v) - @field_validator('password') + @field_validator("password") @classmethod - def password_strength(cls, v: Optional[str]) -> Optional[str]: + def password_strength(cls, v: str | None) -> str | None: """Enterprise-grade password strength validation""" if v is None: return v return validate_password_strength(v) - @field_validator('is_superuser') + @field_validator("is_superuser") @classmethod - def prevent_superuser_modification(cls, v: Optional[bool]) -> Optional[bool]: + def prevent_superuser_modification(cls, v: bool | None) -> bool | None: """Prevent users from modifying their superuser status via this schema.""" if v is not None: raise ValueError("Cannot modify superuser status through user update") @@ -67,7 +69,7 @@ class UserInDB(UserBase): is_active: bool is_superuser: bool created_at: datetime - updated_at: Optional[datetime] = None + updated_at: datetime | None = None model_config = ConfigDict(from_attributes=True) @@ -77,28 +79,28 @@ class UserResponse(UserBase): is_active: bool is_superuser: bool created_at: datetime - updated_at: Optional[datetime] = None + updated_at: datetime | None = None model_config = ConfigDict(from_attributes=True) class Token(BaseModel): access_token: str - refresh_token: Optional[str] = None + refresh_token: str | None = None token_type: str = "bearer" user: "UserResponse" # Forward reference since UserResponse is defined above - expires_in: Optional[int] = None # Token expiration in seconds + expires_in: int | None = None # Token expiration in seconds class TokenPayload(BaseModel): sub: str # User ID exp: int # Expiration time - iat: Optional[int] = None # Issued at - jti: Optional[str] = None # JWT ID - is_superuser: Optional[bool] = False - first_name: Optional[str] = None - email: Optional[str] = None - type: Optional[str] = None # Token type (access/refresh) + iat: int | None = None # Issued at + jti: str | None = None # JWT ID + is_superuser: bool | None = False + first_name: str | None = None + email: str | None = None + type: str | None = None # Token type (access/refresh) class TokenData(BaseModel): @@ -108,10 +110,11 @@ class TokenData(BaseModel): class PasswordChange(BaseModel): """Schema for changing password (requires current password).""" + current_password: str new_password: str - @field_validator('new_password') + @field_validator("new_password") @classmethod def password_strength(cls, v: str) -> str: """Enterprise-grade password strength validation""" @@ -120,10 +123,11 @@ class PasswordChange(BaseModel): class PasswordReset(BaseModel): """Schema for resetting password (via email token).""" + token: str new_password: str - @field_validator('new_password') + @field_validator("new_password") @classmethod def password_strength(cls, v: str) -> str: """Enterprise-grade password strength validation""" @@ -141,23 +145,19 @@ class RefreshTokenRequest(BaseModel): class PasswordResetRequest(BaseModel): """Schema for requesting a password reset.""" + email: EmailStr = Field(..., description="Email address of the account") - model_config = { - "json_schema_extra": { - "example": { - "email": "user@example.com" - } - } - } + model_config = {"json_schema_extra": {"example": {"email": "user@example.com"}}} class PasswordResetConfirm(BaseModel): """Schema for confirming a password reset with token.""" + token: str = Field(..., description="Password reset token from email") new_password: str = Field(..., min_length=8, description="New password") - @field_validator('new_password') + @field_validator("new_password") @classmethod def password_strength(cls, v: str) -> str: """Enterprise-grade password strength validation""" @@ -167,7 +167,7 @@ class PasswordResetConfirm(BaseModel): "json_schema_extra": { "example": { "token": "eyJwYXlsb2FkIjp7ImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTcxMjM0NTY3OH19", - "new_password": "NewSecurePassword123" + "new_password": "NewSecurePassword123", } } } diff --git a/backend/app/schemas/validators.py b/backend/app/schemas/validators.py index 3f98cbf..9b6f745 100644 --- a/backend/app/schemas/validators.py +++ b/backend/app/schemas/validators.py @@ -4,19 +4,34 @@ Shared validators for Pydantic schemas. This module provides reusable validation functions to ensure consistency across all schemas and avoid code duplication. """ + import re -from typing import Set # Common weak passwords that should be rejected -COMMON_PASSWORDS: Set[str] = { - 'password', 'password1', 'password123', 'password1234', - 'admin', 'admin123', 'admin1234', - 'welcome', 'welcome1', 'welcome123', - 'qwerty', 'qwerty123', - '12345678', '123456789', '1234567890', - 'letmein', 'letmein1', 'letmein123', - 'monkey123', 'dragon123', - 'passw0rd', 'p@ssw0rd', 'p@ssword', +COMMON_PASSWORDS: set[str] = { + "password", + "password1", + "password123", + "password1234", + "admin", + "admin123", + "admin1234", + "welcome", + "welcome1", + "welcome123", + "qwerty", + "qwerty123", + "12345678", + "123456789", + "1234567890", + "letmein", + "letmein1", + "letmein123", + "monkey123", + "dragon123", + "passw0rd", + "p@ssw0rd", + "p@ssword", } @@ -47,18 +62,21 @@ def validate_password_strength(password: str) -> str: """ # Check minimum length if len(password) < 12: - raise ValueError('Password must be at least 12 characters long') + raise ValueError("Password must be at least 12 characters long") # Check against common passwords (case-insensitive) if password.lower() in COMMON_PASSWORDS: - raise ValueError('Password is too common. Please choose a stronger password') + raise ValueError("Password is too common. Please choose a stronger password") # Check for required character types checks = [ - (any(c.islower() for c in password), 'at least one lowercase letter'), - (any(c.isupper() for c in password), 'at least one uppercase letter'), - (any(c.isdigit() for c in password), 'at least one digit'), - (any(c in '!@#$%^&*()_+-=[]{}|;:,.<>?~`' for c in password), 'at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)') + (any(c.islower() for c in password), "at least one lowercase letter"), + (any(c.isupper() for c in password), "at least one uppercase letter"), + (any(c.isdigit() for c in password), "at least one digit"), + ( + any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?~`" for c in password), + "at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)", + ), ] failed = [msg for check, msg in checks if not check] @@ -94,10 +112,10 @@ def validate_phone_number(phone: str | None) -> str | None: # Check for empty strings if not phone or phone.strip() == "": - raise ValueError('Phone number cannot be empty') + raise ValueError("Phone number cannot be empty") # Remove all spaces and formatting characters - cleaned = re.sub(r'[\s\-\(\)]', '', phone) + cleaned = re.sub(r"[\s\-\(\)]", "", phone) # Basic pattern: # Must start with + or 0 @@ -105,19 +123,19 @@ def validate_phone_number(phone: str | None) -> str | None: # After 0 must have at least 8 digits # Maximum total length of 15 digits (international standard) # Only allowed characters are + at start and digits - pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$' + pattern = r"^(?:\+[0-9]{8,14}|0[0-9]{8,14})$" if not re.match(pattern, cleaned): - raise ValueError('Phone number must start with + or 0 followed by 8-14 digits') + raise ValueError("Phone number must start with + or 0 followed by 8-14 digits") # Additional validation to catch specific invalid cases # NOTE: These checks are defensive code - the regex pattern above already catches these cases - if cleaned.count('+') > 1: # pragma: no cover - raise ValueError('Phone number can only contain one + symbol at the start') + if cleaned.count("+") > 1: # pragma: no cover + raise ValueError("Phone number can only contain one + symbol at the start") # Check for any non-digit characters (except the leading +) if not all(c.isdigit() for c in cleaned[1:]): # pragma: no cover - raise ValueError('Phone number can only contain digits after the prefix') + raise ValueError("Phone number can only contain digits after the prefix") return cleaned @@ -169,16 +187,16 @@ def validate_slug(slug: str) -> str: ValueError: If slug format is invalid """ if not slug or len(slug) < 2: - raise ValueError('Slug must be at least 2 characters long') + raise ValueError("Slug must be at least 2 characters long") if len(slug) > 50: - raise ValueError('Slug must be at most 50 characters long') + raise ValueError("Slug must be at most 50 characters long") # Check format - if not re.match(r'^[a-z0-9]+(?:-[a-z0-9]+)*$', slug): + if not re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", slug): raise ValueError( - 'Slug can only contain lowercase letters, numbers, and hyphens. ' - 'It cannot start or end with a hyphen, and cannot contain consecutive hyphens' + "Slug can only contain lowercase letters, numbers, and hyphens. " + "It cannot start or end with a hyphen, and cannot contain consecutive hyphens" ) return slug diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index 7ab8f48..bbfdbc7 100755 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -1,18 +1,17 @@ # app/services/auth_service.py import logging -from typing import Optional from uuid import UUID from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.auth import ( - verify_password_async, - get_password_hash_async, + TokenExpiredError, + TokenInvalidError, create_access_token, create_refresh_token, - TokenExpiredError, - TokenInvalidError + get_password_hash_async, + verify_password_async, ) from app.core.config import settings from app.core.exceptions import AuthenticationError @@ -26,7 +25,9 @@ class AuthService: """Service for handling authentication operations""" @staticmethod - async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]: + async def authenticate_user( + db: AsyncSession, email: str, password: str + ) -> User | None: """ Authenticate a user with email and password using async password verification. @@ -87,7 +88,7 @@ class AuthService: last_name=user_data.last_name, phone_number=user_data.phone_number, is_active=True, - is_superuser=False + is_superuser=False, ) db.add(user) @@ -103,8 +104,8 @@ class AuthService: except Exception as e: # Rollback on any database errors await db.rollback() - logger.error(f"Error creating user: {str(e)}", exc_info=True) - raise AuthenticationError(f"Failed to create user: {str(e)}") + logger.error(f"Error creating user: {e!s}", exc_info=True) + raise AuthenticationError(f"Failed to create user: {e!s}") @staticmethod def create_tokens(user: User) -> Token: @@ -121,18 +122,13 @@ class AuthService: claims = { "is_superuser": user.is_superuser, "email": user.email, - "first_name": user.first_name + "first_name": user.first_name, } # Create tokens - access_token = create_access_token( - subject=str(user.id), - claims=claims - ) + access_token = create_access_token(subject=str(user.id), claims=claims) - refresh_token = create_refresh_token( - subject=str(user.id) - ) + refresh_token = create_refresh_token(subject=str(user.id)) # Convert User model to UserResponse schema user_response = UserResponse.model_validate(user) @@ -141,7 +137,8 @@ class AuthService: access_token=access_token, refresh_token=refresh_token, user=user_response, - expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # Convert minutes to seconds + expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES + * 60, # Convert minutes to seconds ) @staticmethod @@ -180,11 +177,13 @@ class AuthService: return AuthService.create_tokens(user) except (TokenExpiredError, TokenInvalidError) as e: - logger.warning(f"Token refresh failed: {str(e)}") + logger.warning(f"Token refresh failed: {e!s}") raise @staticmethod - async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool: + async def change_password( + db: AsyncSession, user_id: UUID, current_password: str, new_password: str + ) -> bool: """ Change a user's password. @@ -223,5 +222,7 @@ class AuthService: except Exception as e: # Rollback on any database errors await db.rollback() - logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True) - raise AuthenticationError(f"Failed to change password: {str(e)}") + logger.error( + f"Error changing password for user {user_id}: {e!s}", exc_info=True + ) + raise AuthenticationError(f"Failed to change password: {e!s}") diff --git a/backend/app/services/email_service.py b/backend/app/services/email_service.py index e693b37..ed41e8a 100644 --- a/backend/app/services/email_service.py +++ b/backend/app/services/email_service.py @@ -5,9 +5,9 @@ Email service with placeholder implementation. This service provides email sending functionality with a simple console/log-based placeholder that can be easily replaced with a real email provider (SendGrid, SES, etc.) """ + import logging from abc import ABC, abstractmethod -from typing import List, Optional from app.core.config import settings @@ -20,13 +20,12 @@ class EmailBackend(ABC): @abstractmethod async def send_email( self, - to: List[str], + to: list[str], subject: str, html_content: str, - text_content: Optional[str] = None + text_content: str | None = None, ) -> bool: """Send an email.""" - pass class ConsoleEmailBackend(EmailBackend): @@ -39,10 +38,10 @@ class ConsoleEmailBackend(EmailBackend): async def send_email( self, - to: List[str], + to: list[str], subject: str, html_content: str, - text_content: Optional[str] = None + text_content: str | None = None, ) -> bool: """ Log email content to console/logs. @@ -88,10 +87,10 @@ class SMTPEmailBackend(EmailBackend): async def send_email( self, - to: List[str], + to: list[str], subject: str, html_content: str, - text_content: Optional[str] = None + text_content: str | None = None, ) -> bool: """Send email via SMTP.""" # TODO: Implement SMTP sending @@ -108,7 +107,7 @@ class EmailService: and can be configured to use different backends (console, SMTP, SendGrid, etc.) """ - def __init__(self, backend: Optional[EmailBackend] = None): + def __init__(self, backend: EmailBackend | None = None): """ Initialize email service with a backend. @@ -118,10 +117,7 @@ class EmailService: self.backend = backend or ConsoleEmailBackend() async def send_password_reset_email( - self, - to_email: str, - reset_token: str, - user_name: Optional[str] = None + self, to_email: str, reset_token: str, user_name: str | None = None ) -> bool: """ Send password reset email. @@ -142,7 +138,7 @@ class EmailService: # Plain text version text_content = f""" -Hello{' ' + user_name if user_name else ''}, +Hello{" " + user_name if user_name else ""}, You requested a password reset for your account. Click the link below to reset your password: @@ -177,7 +173,7 @@ The {settings.PROJECT_NAME} Team

Password Reset

-

Hello{' ' + user_name if user_name else ''},

+

Hello{" " + user_name if user_name else ""},

You requested a password reset for your account. Click the button below to reset your password:

Reset Password @@ -200,17 +196,14 @@ The {settings.PROJECT_NAME} Team to=[to_email], subject=subject, html_content=html_content, - text_content=text_content + text_content=text_content, ) except Exception as e: - logger.error(f"Failed to send password reset email to {to_email}: {str(e)}") + logger.error(f"Failed to send password reset email to {to_email}: {e!s}") return False async def send_email_verification( - self, - to_email: str, - verification_token: str, - user_name: Optional[str] = None + self, to_email: str, verification_token: str, user_name: str | None = None ) -> bool: """ Send email verification email. @@ -224,14 +217,16 @@ The {settings.PROJECT_NAME} Team True if email sent successfully """ # Generate verification URL - verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}" + verification_url = ( + f"{settings.FRONTEND_URL}/verify-email?token={verification_token}" + ) # Prepare email content subject = "Verify Your Email Address" # Plain text version text_content = f""" -Hello{' ' + user_name if user_name else ''}, +Hello{" " + user_name if user_name else ""}, Thank you for signing up! Please verify your email address by clicking the link below: @@ -266,7 +261,7 @@ The {settings.PROJECT_NAME} Team

Verify Your Email

-

Hello{' ' + user_name if user_name else ''},

+

Hello{" " + user_name if user_name else ""},

Thank you for signing up! Please verify your email address by clicking the button below:

Verify Email @@ -289,10 +284,10 @@ The {settings.PROJECT_NAME} Team to=[to_email], subject=subject, html_content=html_content, - text_content=text_content + text_content=text_content, ) except Exception as e: - logger.error(f"Failed to send verification email to {to_email}: {str(e)}") + logger.error(f"Failed to send verification email to {to_email}: {e!s}") return False diff --git a/backend/app/services/session_cleanup.py b/backend/app/services/session_cleanup.py index 230eeda..e993530 100755 --- a/backend/app/services/session_cleanup.py +++ b/backend/app/services/session_cleanup.py @@ -3,8 +3,9 @@ Background job for cleaning up expired sessions. This service runs periodically to remove old session records from the database. """ + import logging -from datetime import datetime, timezone +from datetime import UTC, datetime from app.core.database import SessionLocal from app.crud.session import session as session_crud @@ -39,7 +40,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int: return count except Exception as e: - logger.error(f"Error during session cleanup: {str(e)}", exc_info=True) + logger.error(f"Error during session cleanup: {e!s}", exc_info=True) return 0 @@ -52,20 +53,21 @@ async def get_session_statistics() -> dict: """ async with SessionLocal() as db: try: + from sqlalchemy import func, select + from app.models.user_session import UserSession - from sqlalchemy import select, func total_result = await db.execute(select(func.count(UserSession.id))) total_sessions = total_result.scalar_one() active_result = await db.execute( - select(func.count(UserSession.id)).where(UserSession.is_active == True) + select(func.count(UserSession.id)).where(UserSession.is_active) ) active_sessions = active_result.scalar_one() expired_result = await db.execute( select(func.count(UserSession.id)).where( - UserSession.expires_at < datetime.now(timezone.utc) + UserSession.expires_at < datetime.now(UTC) ) ) expired_sessions = expired_result.scalar_one() @@ -82,5 +84,5 @@ async def get_session_statistics() -> dict: return stats except Exception as e: - logger.error(f"Error getting session statistics: {str(e)}", exc_info=True) + logger.error(f"Error getting session statistics: {e!s}", exc_info=True) return {} diff --git a/backend/app/utils/auth_test_utils.py b/backend/app/utils/auth_test_utils.py index 6a5397e..b19f544 100644 --- a/backend/app/utils/auth_test_utils.py +++ b/backend/app/utils/auth_test_utils.py @@ -2,7 +2,8 @@ Authentication utilities for testing. This module provides tools to bypass FastAPI's authentication in tests. """ -from typing import Callable, Dict, Optional + +from collections.abc import Callable from fastapi import FastAPI from fastapi.security import OAuth2PasswordBearer @@ -13,9 +14,9 @@ from app.models.user import User def create_test_auth_client( - app: FastAPI, - test_user: User, - extra_overrides: Optional[Dict[Callable, Callable]] = None + app: FastAPI, + test_user: User, + extra_overrides: dict[Callable, Callable] | None = None, ) -> TestClient: """ Create a test client with authentication pre-configured. @@ -47,10 +48,7 @@ def create_test_auth_client( return TestClient(app) -def create_test_optional_auth_client( - app: FastAPI, - test_user: User -) -> TestClient: +def create_test_optional_auth_client(app: FastAPI, test_user: User) -> TestClient: """ Create a test client with optional authentication pre-configured. @@ -70,10 +68,7 @@ def create_test_optional_auth_client( return TestClient(app) -def create_test_superuser_client( - app: FastAPI, - test_user: User -) -> TestClient: +def create_test_superuser_client(app: FastAPI, test_user: User) -> TestClient: """ Create a test client with superuser authentication pre-configured. @@ -120,7 +115,7 @@ def cleanup_test_client_auth(app: FastAPI) -> None: auth_deps = [ get_current_user, get_optional_current_user, - OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login"), ] # Remove overrides diff --git a/backend/app/utils/device.py b/backend/app/utils/device.py index d4842c3..de2114a 100644 --- a/backend/app/utils/device.py +++ b/backend/app/utils/device.py @@ -1,8 +1,8 @@ """ Utility functions for extracting and parsing device information from HTTP requests. """ + import re -from typing import Optional from fastapi import Request @@ -19,11 +19,11 @@ def extract_device_info(request: Request) -> DeviceInfo: Returns: DeviceInfo object with parsed device information """ - user_agent = request.headers.get('user-agent', '') + user_agent = request.headers.get("user-agent", "") device_info = DeviceInfo( device_name=parse_device_name(user_agent), - device_id=request.headers.get('x-device-id'), # Client must send this header + device_id=request.headers.get("x-device-id"), # Client must send this header ip_address=get_client_ip(request), user_agent=user_agent[:500] if user_agent else None, # Truncate to max length location_city=None, # Can be populated via IP geolocation service @@ -33,7 +33,7 @@ def extract_device_info(request: Request) -> DeviceInfo: return device_info -def parse_device_name(user_agent: str) -> Optional[str]: +def parse_device_name(user_agent: str) -> str | None: """ Parse user agent string to extract a friendly device name. @@ -54,48 +54,48 @@ def parse_device_name(user_agent: str) -> Optional[str]: user_agent_lower = user_agent.lower() # Mobile devices (check first, as they can contain desktop patterns too) - if 'iphone' in user_agent_lower: + if "iphone" in user_agent_lower: return "iPhone" - elif 'ipad' in user_agent_lower: + elif "ipad" in user_agent_lower: return "iPad" - elif 'android' in user_agent_lower: + elif "android" in user_agent_lower: # Try to extract device model - android_match = re.search(r'android.*;\s*([^)]+)\s*build', user_agent_lower) + android_match = re.search(r"android.*;\s*([^)]+)\s*build", user_agent_lower) if android_match: device_model = android_match.group(1).strip() return f"Android ({device_model.title()})" return "Android device" - elif 'windows phone' in user_agent_lower: + elif "windows phone" in user_agent_lower: return "Windows Phone" # Tablets (check before desktop, as some tablets contain "android") - elif 'tablet' in user_agent_lower: + elif "tablet" in user_agent_lower: return "Tablet" # Smart TVs (check before desktop OS patterns) - elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv']): + elif any(tv in user_agent_lower for tv in ["smart-tv", "smarttv"]): return "Smart TV" # Game consoles (check before desktop OS patterns, as Xbox contains "Windows") - elif 'playstation' in user_agent_lower: + elif "playstation" in user_agent_lower: return "PlayStation" - elif 'xbox' in user_agent_lower: + elif "xbox" in user_agent_lower: return "Xbox" - elif 'nintendo' in user_agent_lower: + elif "nintendo" in user_agent_lower: return "Nintendo" # Desktop operating systems - elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower: + elif "macintosh" in user_agent_lower or "mac os x" in user_agent_lower: # Try to extract browser browser = extract_browser(user_agent) return f"{browser} on Mac" if browser else "Mac" - elif 'windows' in user_agent_lower: + elif "windows" in user_agent_lower: browser = extract_browser(user_agent) return f"{browser} on Windows" if browser else "Windows PC" - elif 'linux' in user_agent_lower and 'android' not in user_agent_lower: + elif "linux" in user_agent_lower and "android" not in user_agent_lower: browser = extract_browser(user_agent) return f"{browser} on Linux" if browser else "Linux" - elif 'cros' in user_agent_lower: + elif "cros" in user_agent_lower: return "Chromebook" # Fallback: just return browser name if detected @@ -106,7 +106,7 @@ def parse_device_name(user_agent: str) -> Optional[str]: return "Unknown device" -def extract_browser(user_agent: str) -> Optional[str]: +def extract_browser(user_agent: str) -> str | None: """ Extract browser name from user agent string. @@ -126,26 +126,26 @@ def extract_browser(user_agent: str) -> Optional[str]: user_agent_lower = user_agent.lower() # Check specific browsers (order matters - check Edge before Chrome!) - if 'edg/' in user_agent_lower or 'edge/' in user_agent_lower: + if "edg/" in user_agent_lower or "edge/" in user_agent_lower: return "Edge" - elif 'opr/' in user_agent_lower or 'opera' in user_agent_lower: + elif "opr/" in user_agent_lower or "opera" in user_agent_lower: return "Opera" - elif 'chrome/' in user_agent_lower: + elif "chrome/" in user_agent_lower: return "Chrome" - elif 'safari/' in user_agent_lower: + elif "safari/" in user_agent_lower: # Make sure it's actually Safari, not Chrome (which also contains "Safari") - if 'chrome' not in user_agent_lower: + if "chrome" not in user_agent_lower: return "Safari" return None - elif 'firefox/' in user_agent_lower: + elif "firefox/" in user_agent_lower: return "Firefox" - elif 'msie' in user_agent_lower or 'trident/' in user_agent_lower: + elif "msie" in user_agent_lower or "trident/" in user_agent_lower: return "Internet Explorer" return None -def get_client_ip(request: Request) -> Optional[str]: +def get_client_ip(request: Request) -> str | None: """ Extract client IP address from request, considering proxy headers. @@ -163,14 +163,14 @@ def get_client_ip(request: Request) -> Optional[str]: - request.client.host is fallback for direct connections """ # Check X-Forwarded-For (common in proxied environments) - x_forwarded_for = request.headers.get('x-forwarded-for') + x_forwarded_for = request.headers.get("x-forwarded-for") if x_forwarded_for: # Get the first IP (original client) - client_ip = x_forwarded_for.split(',')[0].strip() + client_ip = x_forwarded_for.split(",")[0].strip() return client_ip # Check X-Real-IP (used by some proxies like nginx) - x_real_ip = request.headers.get('x-real-ip') + x_real_ip = request.headers.get("x-real-ip") if x_real_ip: return x_real_ip.strip() @@ -195,9 +195,17 @@ def is_mobile_device(user_agent: str) -> bool: return False mobile_patterns = [ - 'mobile', 'android', 'iphone', 'ipad', 'ipod', - 'blackberry', 'windows phone', 'webos', 'opera mini', - 'iemobile', 'mobile safari' + "mobile", + "android", + "iphone", + "ipad", + "ipod", + "blackberry", + "windows phone", + "webos", + "opera mini", + "iemobile", + "mobile safari", ] user_agent_lower = user_agent.lower() @@ -220,7 +228,7 @@ def get_device_type(user_agent: str) -> str: user_agent_lower = user_agent.lower() # Check for tablets first (they can contain "mobile" too) - if 'ipad' in user_agent_lower or 'tablet' in user_agent_lower: + if "ipad" in user_agent_lower or "tablet" in user_agent_lower: return "tablet" # Check for mobile @@ -228,7 +236,7 @@ def get_device_type(user_agent: str) -> str: return "mobile" # Check for desktop OS patterns - if any(os in user_agent_lower for os in ['windows', 'macintosh', 'linux', 'cros']): + if any(os in user_agent_lower for os in ["windows", "macintosh", "linux", "cros"]): return "desktop" return "other" diff --git a/backend/app/utils/security.py b/backend/app/utils/security.py index 303d339..8533a3b 100644 --- a/backend/app/utils/security.py +++ b/backend/app/utils/security.py @@ -5,18 +5,21 @@ This module provides utilities for creating and verifying signed tokens, useful for operations like file uploads, password resets, or any other time-limited, single-use operations. """ + import base64 import hashlib import hmac import json import secrets import time -from typing import Dict, Any, Optional +from typing import Any from app.core.config import settings -def create_upload_token(file_path: str, content_type: str, expires_in: int = 300) -> str: +def create_upload_token( + file_path: str, content_type: str, expires_in: int = 300 +) -> str: """ Create a signed token for secure file uploads. @@ -40,34 +43,29 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300 "path": file_path, "content_type": content_type, "exp": int(time.time()) + expires_in, - "nonce": secrets.token_hex(8) # Add randomness to prevent token reuse + "nonce": secrets.token_hex(8), # Add randomness to prevent token reuse } # Convert to JSON and encode - payload_bytes = json.dumps(payload).encode('utf-8') + payload_bytes = json.dumps(payload).encode("utf-8") # Create a signature using HMAC-SHA256 for security # This prevents length extension attacks that plain SHA-256 is vulnerable to signature = hmac.new( - settings.SECRET_KEY.encode('utf-8'), - payload_bytes, - hashlib.sha256 + settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256 ).hexdigest() # Combine payload and signature - token_data = { - "payload": payload, - "signature": signature - } + token_data = {"payload": payload, "signature": signature} # Encode the final token token_json = json.dumps(token_data) - token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8') + token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8") return token -def verify_upload_token(token: str) -> Optional[Dict[str, Any]]: +def verify_upload_token(token: str) -> dict[str, Any] | None: """ Verify an upload token and return the payload if valid. @@ -88,7 +86,7 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]: """ try: # Decode the token - token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8") token_data = json.loads(token_json) # Extract payload and signature @@ -96,11 +94,9 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]: signature = token_data["signature"] # Verify signature using HMAC and constant-time comparison - payload_bytes = json.dumps(payload).encode('utf-8') + payload_bytes = json.dumps(payload).encode("utf-8") expected_signature = hmac.new( - settings.SECRET_KEY.encode('utf-8'), - payload_bytes, - hashlib.sha256 + settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256 ).hexdigest() if not hmac.compare_digest(signature, expected_signature): @@ -136,34 +132,29 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str: "email": email, "exp": int(time.time()) + expires_in, "nonce": secrets.token_hex(16), # Extra randomness - "purpose": "password_reset" + "purpose": "password_reset", } # Convert to JSON and encode - payload_bytes = json.dumps(payload).encode('utf-8') + payload_bytes = json.dumps(payload).encode("utf-8") # Create a signature using HMAC-SHA256 for security # This prevents length extension attacks that plain SHA-256 is vulnerable to signature = hmac.new( - settings.SECRET_KEY.encode('utf-8'), - payload_bytes, - hashlib.sha256 + settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256 ).hexdigest() # Combine payload and signature - token_data = { - "payload": payload, - "signature": signature - } + token_data = {"payload": payload, "signature": signature} # Encode the final token token_json = json.dumps(token_data) - token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8') + token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8") return token -def verify_password_reset_token(token: str) -> Optional[str]: +def verify_password_reset_token(token: str) -> str | None: """ Verify a password reset token and return the email if valid. @@ -182,7 +173,7 @@ def verify_password_reset_token(token: str) -> Optional[str]: """ try: # Decode the token - token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8") token_data = json.loads(token_json) # Extract payload and signature @@ -194,11 +185,9 @@ def verify_password_reset_token(token: str) -> Optional[str]: return None # Verify signature using HMAC and constant-time comparison - payload_bytes = json.dumps(payload).encode('utf-8') + payload_bytes = json.dumps(payload).encode("utf-8") expected_signature = hmac.new( - settings.SECRET_KEY.encode('utf-8'), - payload_bytes, - hashlib.sha256 + settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256 ).hexdigest() if not hmac.compare_digest(signature, expected_signature): @@ -234,34 +223,29 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str: "email": email, "exp": int(time.time()) + expires_in, "nonce": secrets.token_hex(16), - "purpose": "email_verification" + "purpose": "email_verification", } # Convert to JSON and encode - payload_bytes = json.dumps(payload).encode('utf-8') + payload_bytes = json.dumps(payload).encode("utf-8") # Create a signature using HMAC-SHA256 for security # This prevents length extension attacks that plain SHA-256 is vulnerable to signature = hmac.new( - settings.SECRET_KEY.encode('utf-8'), - payload_bytes, - hashlib.sha256 + settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256 ).hexdigest() # Combine payload and signature - token_data = { - "payload": payload, - "signature": signature - } + token_data = {"payload": payload, "signature": signature} # Encode the final token token_json = json.dumps(token_data) - token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8') + token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8") return token -def verify_email_verification_token(token: str) -> Optional[str]: +def verify_email_verification_token(token: str) -> str | None: """ Verify an email verification token and return the email if valid. @@ -280,7 +264,7 @@ def verify_email_verification_token(token: str) -> Optional[str]: """ try: # Decode the token - token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8") token_data = json.loads(token_json) # Extract payload and signature @@ -292,11 +276,9 @@ def verify_email_verification_token(token: str) -> Optional[str]: return None # Verify signature using HMAC and constant-time comparison - payload_bytes = json.dumps(payload).encode('utf-8') + payload_bytes = json.dumps(payload).encode("utf-8") expected_signature = hmac.new( - settings.SECRET_KEY.encode('utf-8'), - payload_bytes, - hashlib.sha256 + settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256 ).hexdigest() if not hmac.compare_digest(signature, expected_signature): diff --git a/backend/app/utils/test_utils.py b/backend/app/utils/test_utils.py index ce33e84..fa89d59 100644 --- a/backend/app/utils/test_utils.py +++ b/backend/app/utils/test_utils.py @@ -9,17 +9,19 @@ from app.core.database import Base logger = logging.getLogger(__name__) + def get_test_engine(): """Create an SQLite in-memory engine specifically for testing""" test_engine = create_engine( "sqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool, # Use static pool for in-memory testing - echo=False + echo=False, ) return test_engine + def setup_test_db(): """Create a test database and session factory""" # Create a new engine for this test run @@ -30,14 +32,12 @@ def setup_test_db(): # Create session factory TestingSessionLocal = sessionmaker( - autocommit=False, - autoflush=False, - bind=test_engine, - expire_on_commit=False + autocommit=False, autoflush=False, bind=test_engine, expire_on_commit=False ) return test_engine, TestingSessionLocal + def teardown_test_db(engine): """Clean up after tests""" # Drop all tables @@ -46,13 +46,14 @@ def teardown_test_db(engine): # Dispose of engine engine.dispose() + async def get_async_test_engine(): """Create an async SQLite in-memory engine specifically for testing""" test_engine = create_async_engine( "sqlite+aiosqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool, # Use static pool for in-memory testing - echo=False + echo=False, ) return test_engine @@ -69,7 +70,7 @@ async def setup_async_test_db(): autoflush=False, bind=test_engine, expire_on_commit=False, - class_=AsyncSession + class_=AsyncSession, ) return test_engine, AsyncTestingSessionLocal diff --git a/backend/tests/api/dependencies/test_auth_dependencies.py b/backend/tests/api/dependencies/test_auth_dependencies.py index 3de6b05..5da4a55 100755 --- a/backend/tests/api/dependencies/test_auth_dependencies.py +++ b/backend/tests/api/dependencies/test_auth_dependencies.py @@ -1,15 +1,16 @@ # tests/api/dependencies/test_auth_dependencies.py -import pytest -import pytest_asyncio import uuid from unittest.mock import patch + +import pytest +import pytest_asyncio from fastapi import HTTPException from app.api.dependencies.auth import ( - get_current_user, get_current_active_user, get_current_superuser, - get_optional_current_user + get_current_user, + get_optional_current_user, ) from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash from app.models.user import User @@ -24,7 +25,7 @@ def mock_token(): @pytest_asyncio.fixture async def async_mock_user(async_test_db): """Async fixture to create and return a mock User instance.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: mock_user = User( id=uuid.uuid4(), @@ -47,12 +48,14 @@ class TestGetCurrentUser: """Tests for get_current_user dependency""" @pytest.mark.asyncio - async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token): + async def test_get_current_user_success( + self, async_test_db, async_mock_user, mock_token + ): """Test successfully getting the current user""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to return user_id that matches our mock_user - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = async_mock_user.id # Call the dependency @@ -65,12 +68,12 @@ class TestGetCurrentUser: @pytest.mark.asyncio async def test_get_current_user_nonexistent(self, async_test_db, mock_token): """Test when the token contains a user ID that doesn't exist""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to return a non-existent user ID nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111") - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = nonexistent_id # Should raise HTTPException with 404 status @@ -81,19 +84,24 @@ class TestGetCurrentUser: assert "User not found" in exc_info.value.detail @pytest.mark.asyncio - async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token): + async def test_get_current_user_inactive( + self, async_test_db, async_mock_user, mock_token + ): """Test when the user is inactive""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Get the user in this session and make it inactive from sqlalchemy import select - result = await session.execute(select(User).where(User.id == async_mock_user.id)) + + result = await session.execute( + select(User).where(User.id == async_mock_user.id) + ) user_in_session = result.scalar_one_or_none() user_in_session.is_active = False await session.commit() # Mock get_token_data - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = async_mock_user.id # Should raise HTTPException with 403 status @@ -106,10 +114,10 @@ class TestGetCurrentUser: @pytest.mark.asyncio async def test_get_current_user_expired_token(self, async_test_db, mock_token): """Test with an expired token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to raise TokenExpiredError - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.side_effect = TokenExpiredError("Token expired") # Should raise HTTPException with 401 status @@ -122,10 +130,10 @@ class TestGetCurrentUser: @pytest.mark.asyncio async def test_get_current_user_invalid_token(self, async_test_db, mock_token): """Test with an invalid token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to raise TokenInvalidError - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.side_effect = TokenInvalidError("Invalid token") # Should raise HTTPException with 401 status @@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser: """Tests for get_optional_current_user dependency""" @pytest.mark.asyncio - async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token): + async def test_get_optional_current_user_with_token( + self, async_test_db, async_mock_user, mock_token + ): """Test getting optional user with a valid token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = async_mock_user.id # Call the dependency @@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser: @pytest.mark.asyncio async def test_get_optional_current_user_no_token(self, async_test_db): """Test getting optional user with no token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Call the dependency with no token user = await get_optional_current_user(db=session, token=None) @@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser: assert user is None @pytest.mark.asyncio - async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token): + async def test_get_optional_current_user_invalid_token( + self, async_test_db, mock_token + ): """Test getting optional user with an invalid token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to raise TokenInvalidError - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.side_effect = TokenInvalidError("Invalid token") # Call the dependency @@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser: assert user is None @pytest.mark.asyncio - async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token): + async def test_get_optional_current_user_expired_token( + self, async_test_db, mock_token + ): """Test getting optional user with an expired token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to raise TokenExpiredError - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.side_effect = TokenExpiredError("Token expired") # Call the dependency @@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser: assert user is None @pytest.mark.asyncio - async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token): + async def test_get_optional_current_user_inactive( + self, async_test_db, async_mock_user, mock_token + ): """Test getting optional user when user is inactive""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Get the user in this session and make it inactive from sqlalchemy import select - result = await session.execute(select(User).where(User.id == async_mock_user.id)) + + result = await session.execute( + select(User).where(User.id == async_mock_user.id) + ) user_in_session = result.scalar_one_or_none() user_in_session.is_active = False await session.commit() # Mock get_token_data - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = async_mock_user.id # Call the dependency diff --git a/backend/tests/api/routes/test_health.py b/backend/tests/api/routes/test_health.py index 47f9bce..5333c28 100755 --- a/backend/tests/api/routes/test_health.py +++ b/backend/tests/api/routes/test_health.py @@ -1,13 +1,12 @@ # tests/api/routes/test_health.py +from datetime import datetime +from unittest.mock import patch + import pytest -from unittest.mock import AsyncMock, patch, MagicMock from fastapi import status from fastapi.testclient import TestClient -from datetime import datetime -from sqlalchemy.exc import OperationalError from app.main import app -from app.core.database import get_db @pytest.fixture @@ -121,7 +120,10 @@ class TestHealthEndpoint: response = client.get("/health") # Should succeed without authentication - assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE] + assert response.status_code in [ + status.HTTP_200_OK, + status.HTTP_503_SERVICE_UNAVAILABLE, + ] def test_health_check_idempotent(self, client): """Test that multiple health checks return consistent results""" @@ -142,7 +144,10 @@ class TestHealthEndpoint: assert data1["environment"] == data2["environment"] # Same database check status - assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"] + assert ( + data1["checks"]["database"]["status"] + == data2["checks"]["database"]["status"] + ) def test_health_check_content_type(self, client): """Test that health check returns JSON content type""" diff --git a/backend/tests/api/test_admin.py b/backend/tests/api/test_admin.py index 53131c7..24e73e4 100644 --- a/backend/tests/api/test_admin.py +++ b/backend/tests/api/test_admin.py @@ -2,15 +2,17 @@ """ Comprehensive tests for admin endpoints. """ + +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + import pytest import pytest_asyncio -from uuid import uuid4 from fastapi import status from app.models.organization import Organization -from app.models.user_organization import UserOrganization, OrganizationRole +from app.models.user_organization import OrganizationRole, UserOrganization from app.models.user_session import UserSession -from datetime import datetime, timezone, timedelta @pytest_asyncio.fixture @@ -18,10 +20,7 @@ async def superuser_token(client, async_test_superuser): """Get access token for superuser.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "superuser@example.com", - "password": "SuperPassword123!" - } + json={"email": "superuser@example.com", "password": "SuperPassword123!"}, ) assert response.status_code == 200, f"Login failed: {response.json()}" return response.json()["access_token"] @@ -29,6 +28,7 @@ async def superuser_token(client, async_test_superuser): # ===== USER MANAGEMENT TESTS ===== + class TestAdminListUsers: """Tests for GET /admin/users endpoint.""" @@ -37,7 +37,7 @@ class TestAdminListUsers: """Test successfully listing users as admin.""" response = await client.get( "/api/v1/admin/users", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -47,27 +47,30 @@ class TestAdminListUsers: assert isinstance(data["data"], list) @pytest.mark.asyncio - async def test_admin_list_users_with_filters(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_list_users_with_filters( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test listing users with filters.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create inactive user async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User + inactive_user = User( email="inactive@example.com", password_hash=get_password_hash("TestPassword123!"), first_name="Inactive", last_name="User", - is_active=False + is_active=False, ) session.add(inactive_user) await session.commit() response = await client.get( "/api/v1/admin/users?is_active=false", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -75,11 +78,13 @@ class TestAdminListUsers: assert len(data["data"]) >= 1 @pytest.mark.asyncio - async def test_admin_list_users_with_search(self, client, async_test_superuser, superuser_token): + async def test_admin_list_users_with_search( + self, client, async_test_superuser, superuser_token + ): """Test searching users.""" response = await client.get( "/api/v1/admin/users?search=superuser", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -92,13 +97,12 @@ class TestAdminListUsers: # Login as regular user login_response = await client.post( "/api/v1/auth/login", - json={"email": async_test_user.email, "password": "TestPassword123!"} + json={"email": async_test_user.email, "password": "TestPassword123!"}, ) token = login_response.json()["access_token"] response = await client.get( - "/api/v1/admin/users", - headers={"Authorization": f"Bearer {token}"} + "/api/v1/admin/users", headers={"Authorization": f"Bearer {token}"} ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -108,7 +112,9 @@ class TestAdminCreateUser: """Tests for POST /admin/users endpoint.""" @pytest.mark.asyncio - async def test_admin_create_user_success(self, client, async_test_superuser, superuser_token): + async def test_admin_create_user_success( + self, client, async_test_superuser, superuser_token + ): """Test successfully creating a user as admin.""" response = await client.post( "/api/v1/admin/users", @@ -116,9 +122,9 @@ class TestAdminCreateUser: "email": "newadminuser@example.com", "password": "SecurePassword123!", "first_name": "New", - "last_name": "User" + "last_name": "User", }, - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_201_CREATED @@ -126,7 +132,9 @@ class TestAdminCreateUser: assert data["email"] == "newadminuser@example.com" @pytest.mark.asyncio - async def test_admin_create_user_duplicate_email(self, client, async_test_superuser, async_test_user, superuser_token): + async def test_admin_create_user_duplicate_email( + self, client, async_test_superuser, async_test_user, superuser_token + ): """Test creating user with duplicate email fails.""" response = await client.post( "/api/v1/admin/users", @@ -134,9 +142,9 @@ class TestAdminCreateUser: "email": async_test_user.email, "password": "SecurePassword123!", "first_name": "Duplicate", - "last_name": "User" + "last_name": "User", }, - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -146,11 +154,13 @@ class TestAdminGetUser: """Tests for GET /admin/users/{user_id} endpoint.""" @pytest.mark.asyncio - async def test_admin_get_user_success(self, client, async_test_superuser, async_test_user, superuser_token): + async def test_admin_get_user_success( + self, client, async_test_superuser, async_test_user, superuser_token + ): """Test successfully getting user details.""" response = await client.get( f"/api/v1/admin/users/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -159,11 +169,13 @@ class TestAdminGetUser: assert data["email"] == async_test_user.email @pytest.mark.asyncio - async def test_admin_get_user_not_found(self, client, async_test_superuser, superuser_token): + async def test_admin_get_user_not_found( + self, client, async_test_superuser, superuser_token + ): """Test getting non-existent user.""" response = await client.get( f"/api/v1/admin/users/{uuid4()}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -173,12 +185,14 @@ class TestAdminUpdateUser: """Tests for PUT /admin/users/{user_id} endpoint.""" @pytest.mark.asyncio - async def test_admin_update_user_success(self, client, async_test_superuser, async_test_user, superuser_token): + async def test_admin_update_user_success( + self, client, async_test_superuser, async_test_user, superuser_token + ): """Test successfully updating a user.""" response = await client.put( f"/api/v1/admin/users/{async_test_user.id}", json={"first_name": "Updated"}, - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -186,12 +200,14 @@ class TestAdminUpdateUser: assert data["first_name"] == "Updated" @pytest.mark.asyncio - async def test_admin_update_user_not_found(self, client, async_test_superuser, superuser_token): + async def test_admin_update_user_not_found( + self, client, async_test_superuser, superuser_token + ): """Test updating non-existent user.""" response = await client.put( f"/api/v1/admin/users/{uuid4()}", json={"first_name": "Updated"}, - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -201,19 +217,22 @@ class TestAdminDeleteUser: """Tests for DELETE /admin/users/{user_id} endpoint.""" @pytest.mark.asyncio - async def test_admin_delete_user_success(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_delete_user_success( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test successfully deleting a user.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create user to delete async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User + user_to_delete = User( email="todelete@example.com", password_hash=get_password_hash("TestPassword123!"), first_name="To", - last_name="Delete" + last_name="Delete", ) session.add(user_to_delete) await session.commit() @@ -221,7 +240,7 @@ class TestAdminDeleteUser: response = await client.delete( f"/api/v1/admin/users/{user_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -229,21 +248,25 @@ class TestAdminDeleteUser: assert data["success"] is True @pytest.mark.asyncio - async def test_admin_delete_user_not_found(self, client, async_test_superuser, superuser_token): + async def test_admin_delete_user_not_found( + self, client, async_test_superuser, superuser_token + ): """Test deleting non-existent user.""" response = await client.delete( f"/api/v1/admin/users/{uuid4()}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_admin_delete_self_forbidden(self, client, async_test_superuser, superuser_token): + async def test_admin_delete_self_forbidden( + self, client, async_test_superuser, superuser_token + ): """Test admin cannot delete their own account.""" response = await client.delete( f"/api/v1/admin/users/{async_test_superuser.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -253,20 +276,23 @@ class TestAdminActivateUser: """Tests for POST /admin/users/{user_id}/activate endpoint.""" @pytest.mark.asyncio - async def test_admin_activate_user_success(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_activate_user_success( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test successfully activating a user.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create inactive user async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User + inactive_user = User( email="toactivate@example.com", password_hash=get_password_hash("TestPassword123!"), first_name="To", last_name="Activate", - is_active=False + is_active=False, ) session.add(inactive_user) await session.commit() @@ -274,7 +300,7 @@ class TestAdminActivateUser: response = await client.post( f"/api/v1/admin/users/{user_id}/activate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -282,11 +308,13 @@ class TestAdminActivateUser: assert data["success"] is True @pytest.mark.asyncio - async def test_admin_activate_user_not_found(self, client, async_test_superuser, superuser_token): + async def test_admin_activate_user_not_found( + self, client, async_test_superuser, superuser_token + ): """Test activating non-existent user.""" response = await client.post( f"/api/v1/admin/users/{uuid4()}/activate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -296,11 +324,13 @@ class TestAdminDeactivateUser: """Tests for POST /admin/users/{user_id}/deactivate endpoint.""" @pytest.mark.asyncio - async def test_admin_deactivate_user_success(self, client, async_test_superuser, async_test_user, superuser_token): + async def test_admin_deactivate_user_success( + self, client, async_test_superuser, async_test_user, superuser_token + ): """Test successfully deactivating a user.""" response = await client.post( f"/api/v1/admin/users/{async_test_user.id}/deactivate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -308,21 +338,25 @@ class TestAdminDeactivateUser: assert data["success"] is True @pytest.mark.asyncio - async def test_admin_deactivate_user_not_found(self, client, async_test_superuser, superuser_token): + async def test_admin_deactivate_user_not_found( + self, client, async_test_superuser, superuser_token + ): """Test deactivating non-existent user.""" response = await client.post( f"/api/v1/admin/users/{uuid4()}/deactivate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_admin_deactivate_self_forbidden(self, client, async_test_superuser, superuser_token): + async def test_admin_deactivate_self_forbidden( + self, client, async_test_superuser, superuser_token + ): """Test admin cannot deactivate their own account.""" response = await client.post( f"/api/v1/admin/users/{async_test_superuser.id}/deactivate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -332,22 +366,25 @@ class TestAdminBulkUserAction: """Tests for POST /admin/users/bulk-action endpoint.""" @pytest.mark.asyncio - async def test_admin_bulk_activate_users(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_bulk_activate_users( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test bulk activating users.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create inactive users user_ids = [] async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User + for i in range(3): user = User( email=f"bulk{i}@example.com", password_hash=get_password_hash("TestPassword123!"), first_name=f"Bulk{i}", last_name="User", - is_active=False + is_active=False, ) session.add(user) await session.flush() @@ -356,11 +393,8 @@ class TestAdminBulkUserAction: response = await client.post( "/api/v1/admin/users/bulk-action", - json={ - "action": "activate", - "user_ids": user_ids - }, - headers={"Authorization": f"Bearer {superuser_token}"} + json={"action": "activate", "user_ids": user_ids}, + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -368,22 +402,25 @@ class TestAdminBulkUserAction: assert data["affected_count"] == 3 @pytest.mark.asyncio - async def test_admin_bulk_deactivate_users(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_bulk_deactivate_users( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test bulk deactivating users.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create active users user_ids = [] async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User + for i in range(2): user = User( email=f"deactivate{i}@example.com", password_hash=get_password_hash("TestPassword123!"), first_name=f"Deactivate{i}", last_name="User", - is_active=True + is_active=True, ) session.add(user) await session.flush() @@ -392,11 +429,8 @@ class TestAdminBulkUserAction: response = await client.post( "/api/v1/admin/users/bulk-action", - json={ - "action": "deactivate", - "user_ids": user_ids - }, - headers={"Authorization": f"Bearer {superuser_token}"} + json={"action": "deactivate", "user_ids": user_ids}, + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -404,21 +438,24 @@ class TestAdminBulkUserAction: assert data["affected_count"] == 2 @pytest.mark.asyncio - async def test_admin_bulk_delete_users(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_bulk_delete_users( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test bulk deleting users.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create users to delete user_ids = [] async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User + for i in range(2): user = User( email=f"bulkdelete{i}@example.com", password_hash=get_password_hash("TestPassword123!"), first_name=f"BulkDelete{i}", - last_name="User" + last_name="User", ) session.add(user) await session.flush() @@ -427,11 +464,8 @@ class TestAdminBulkUserAction: response = await client.post( "/api/v1/admin/users/bulk-action", - json={ - "action": "delete", - "user_ids": user_ids - }, - headers={"Authorization": f"Bearer {superuser_token}"} + json={"action": "delete", "user_ids": user_ids}, + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -441,13 +475,16 @@ class TestAdminBulkUserAction: # ===== ORGANIZATION MANAGEMENT TESTS ===== + class TestAdminListOrganizations: """Tests for GET /admin/organizations endpoint.""" @pytest.mark.asyncio - async def test_admin_list_organizations_success(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_list_organizations_success( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test successfully listing organizations.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organization async with AsyncTestingSessionLocal() as session: @@ -457,7 +494,7 @@ class TestAdminListOrganizations: response = await client.get( "/api/v1/admin/organizations", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -466,9 +503,11 @@ class TestAdminListOrganizations: assert "pagination" in data @pytest.mark.asyncio - async def test_admin_list_organizations_with_search(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_list_organizations_with_search( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test searching organizations.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organization async with AsyncTestingSessionLocal() as session: @@ -478,7 +517,7 @@ class TestAdminListOrganizations: response = await client.get( "/api/v1/admin/organizations?search=Searchable", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -488,16 +527,18 @@ class TestAdminCreateOrganization: """Tests for POST /admin/organizations endpoint.""" @pytest.mark.asyncio - async def test_admin_create_organization_success(self, client, async_test_superuser, superuser_token): + async def test_admin_create_organization_success( + self, client, async_test_superuser, superuser_token + ): """Test successfully creating an organization.""" response = await client.post( "/api/v1/admin/organizations", json={ "name": "New Admin Org", "slug": "new-admin-org", - "description": "Created by admin" + "description": "Created by admin", }, - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_201_CREATED @@ -506,9 +547,11 @@ class TestAdminCreateOrganization: assert data["member_count"] == 0 @pytest.mark.asyncio - async def test_admin_create_organization_duplicate_slug(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_create_organization_duplicate_slug( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test creating organization with duplicate slug fails.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create existing organization async with AsyncTestingSessionLocal() as session: @@ -518,11 +561,8 @@ class TestAdminCreateOrganization: response = await client.post( "/api/v1/admin/organizations", - json={ - "name": "Duplicate", - "slug": "duplicate-slug" - }, - headers={"Authorization": f"Bearer {superuser_token}"} + json={"name": "Duplicate", "slug": "duplicate-slug"}, + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -532,9 +572,11 @@ class TestAdminGetOrganization: """Tests for GET /admin/organizations/{org_id} endpoint.""" @pytest.mark.asyncio - async def test_admin_get_organization_success(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_get_organization_success( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test successfully getting organization details.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organization async with AsyncTestingSessionLocal() as session: @@ -545,7 +587,7 @@ class TestAdminGetOrganization: response = await client.get( f"/api/v1/admin/organizations/{org_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -553,11 +595,13 @@ class TestAdminGetOrganization: assert data["name"] == "Get Test Org" @pytest.mark.asyncio - async def test_admin_get_organization_not_found(self, client, async_test_superuser, superuser_token): + async def test_admin_get_organization_not_found( + self, client, async_test_superuser, superuser_token + ): """Test getting non-existent organization.""" response = await client.get( f"/api/v1/admin/organizations/{uuid4()}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -567,9 +611,11 @@ class TestAdminUpdateOrganization: """Tests for PUT /admin/organizations/{org_id} endpoint.""" @pytest.mark.asyncio - async def test_admin_update_organization_success(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_update_organization_success( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test successfully updating an organization.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organization async with AsyncTestingSessionLocal() as session: @@ -581,7 +627,7 @@ class TestAdminUpdateOrganization: response = await client.put( f"/api/v1/admin/organizations/{org_id}", json={"name": "Updated Name"}, - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -589,12 +635,14 @@ class TestAdminUpdateOrganization: assert data["name"] == "Updated Name" @pytest.mark.asyncio - async def test_admin_update_organization_not_found(self, client, async_test_superuser, superuser_token): + async def test_admin_update_organization_not_found( + self, client, async_test_superuser, superuser_token + ): """Test updating non-existent organization.""" response = await client.put( f"/api/v1/admin/organizations/{uuid4()}", json={"name": "Updated"}, - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -604,9 +652,11 @@ class TestAdminDeleteOrganization: """Tests for DELETE /admin/organizations/{org_id} endpoint.""" @pytest.mark.asyncio - async def test_admin_delete_organization_success(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_delete_organization_success( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test successfully deleting an organization.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organization async with AsyncTestingSessionLocal() as session: @@ -617,7 +667,7 @@ class TestAdminDeleteOrganization: response = await client.delete( f"/api/v1/admin/organizations/{org_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -625,11 +675,13 @@ class TestAdminDeleteOrganization: assert data["success"] is True @pytest.mark.asyncio - async def test_admin_delete_organization_not_found(self, client, async_test_superuser, superuser_token): + async def test_admin_delete_organization_not_found( + self, client, async_test_superuser, superuser_token + ): """Test deleting non-existent organization.""" response = await client.delete( f"/api/v1/admin/organizations/{uuid4()}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -639,9 +691,16 @@ class TestAdminListOrganizationMembers: """Tests for GET /admin/organizations/{org_id}/members endpoint.""" @pytest.mark.asyncio - async def test_admin_list_organization_members_success(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + async def test_admin_list_organization_members_success( + self, + client, + async_test_superuser, + async_test_db, + async_test_user, + superuser_token, + ): """Test successfully listing organization members.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organization with member async with AsyncTestingSessionLocal() as session: @@ -653,7 +712,7 @@ class TestAdminListOrganizationMembers: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -661,7 +720,7 @@ class TestAdminListOrganizationMembers: response = await client.get( f"/api/v1/admin/organizations/{org_id}/members", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -670,11 +729,13 @@ class TestAdminListOrganizationMembers: assert len(data["data"]) >= 1 @pytest.mark.asyncio - async def test_admin_list_organization_members_not_found(self, client, async_test_superuser, superuser_token): + async def test_admin_list_organization_members_not_found( + self, client, async_test_superuser, superuser_token + ): """Test listing members of non-existent organization.""" response = await client.get( f"/api/v1/admin/organizations/{uuid4()}/members", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -684,9 +745,16 @@ class TestAdminAddOrganizationMember: """Tests for POST /admin/organizations/{org_id}/members endpoint.""" @pytest.mark.asyncio - async def test_admin_add_organization_member_success(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + async def test_admin_add_organization_member_success( + self, + client, + async_test_superuser, + async_test_db, + async_test_user, + superuser_token, + ): """Test successfully adding a member to organization.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organization async with AsyncTestingSessionLocal() as session: @@ -697,11 +765,8 @@ class TestAdminAddOrganizationMember: response = await client.post( f"/api/v1/admin/organizations/{org_id}/members", - json={ - "user_id": str(async_test_user.id), - "role": "member" - }, - headers={"Authorization": f"Bearer {superuser_token}"} + json={"user_id": str(async_test_user.id), "role": "member"}, + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -709,9 +774,16 @@ class TestAdminAddOrganizationMember: assert data["success"] is True @pytest.mark.asyncio - async def test_admin_add_organization_member_already_exists(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + async def test_admin_add_organization_member_already_exists( + self, + client, + async_test_superuser, + async_test_db, + async_test_user, + superuser_token, + ): """Test adding member who is already a member.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization with existing member async with AsyncTestingSessionLocal() as session: @@ -723,7 +795,7 @@ class TestAdminAddOrganizationMember: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -731,33 +803,31 @@ class TestAdminAddOrganizationMember: response = await client.post( f"/api/v1/admin/organizations/{org_id}/members", - json={ - "user_id": str(async_test_user.id), - "role": "member" - }, - headers={"Authorization": f"Bearer {superuser_token}"} + json={"user_id": str(async_test_user.id), "role": "member"}, + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_409_CONFLICT @pytest.mark.asyncio - async def test_admin_add_organization_member_org_not_found(self, client, async_test_superuser, async_test_user, superuser_token): + async def test_admin_add_organization_member_org_not_found( + self, client, async_test_superuser, async_test_user, superuser_token + ): """Test adding member to non-existent organization.""" response = await client.post( f"/api/v1/admin/organizations/{uuid4()}/members", - json={ - "user_id": str(async_test_user.id), - "role": "member" - }, - headers={"Authorization": f"Bearer {superuser_token}"} + json={"user_id": str(async_test_user.id), "role": "member"}, + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_admin_add_organization_member_user_not_found(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_add_organization_member_user_not_found( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test adding non-existent user to organization.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organization async with AsyncTestingSessionLocal() as session: @@ -768,11 +838,8 @@ class TestAdminAddOrganizationMember: response = await client.post( f"/api/v1/admin/organizations/{org_id}/members", - json={ - "user_id": str(uuid4()), - "role": "member" - }, - headers={"Authorization": f"Bearer {superuser_token}"} + json={"user_id": str(uuid4()), "role": "member"}, + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -782,9 +849,16 @@ class TestAdminRemoveOrganizationMember: """Tests for DELETE /admin/organizations/{org_id}/members/{user_id} endpoint.""" @pytest.mark.asyncio - async def test_admin_remove_organization_member_success(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + async def test_admin_remove_organization_member_success( + self, + client, + async_test_superuser, + async_test_db, + async_test_user, + superuser_token, + ): """Test successfully removing a member from organization.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization with member async with AsyncTestingSessionLocal() as session: @@ -796,7 +870,7 @@ class TestAdminRemoveOrganizationMember: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -804,7 +878,7 @@ class TestAdminRemoveOrganizationMember: response = await client.delete( f"/api/v1/admin/organizations/{org_id}/members/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -812,9 +886,16 @@ class TestAdminRemoveOrganizationMember: assert data["success"] is True @pytest.mark.asyncio - async def test_admin_remove_organization_member_not_member(self, client, async_test_superuser, async_test_db, async_test_user, superuser_token): + async def test_admin_remove_organization_member_not_member( + self, + client, + async_test_superuser, + async_test_db, + async_test_user, + superuser_token, + ): """Test removing user who is not a member.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization without member async with AsyncTestingSessionLocal() as session: @@ -825,17 +906,19 @@ class TestAdminRemoveOrganizationMember: response = await client.delete( f"/api/v1/admin/organizations/{org_id}/members/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_admin_remove_organization_member_org_not_found(self, client, async_test_superuser, async_test_user, superuser_token): + async def test_admin_remove_organization_member_org_not_found( + self, client, async_test_superuser, async_test_user, superuser_token + ): """Test removing member from non-existent organization.""" response = await client.delete( f"/api/v1/admin/organizations/{uuid4()}/members/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -843,17 +926,25 @@ class TestAdminRemoveOrganizationMember: # ===== SESSION MANAGEMENT TESTS ===== + class TestAdminListSessions: """Tests for admin sessions list endpoint.""" @pytest.mark.asyncio - async def test_admin_list_sessions_success(self, client, async_test_superuser, async_test_user, async_test_db, superuser_token): + async def test_admin_list_sessions_success( + self, + client, + async_test_superuser, + async_test_user, + async_test_db, + superuser_token, + ): """Test listing all sessions as admin.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create some test sessions async with AsyncTestingSessionLocal() as session: - now = datetime.now(timezone.utc) + now = datetime.now(UTC) expires_at = now + timedelta(days=7) session1 = UserSession( @@ -867,7 +958,7 @@ class TestAdminListSessions: expires_at=expires_at, is_active=True, location_city="San Francisco", - location_country="United States" + location_country="United States", ) session2 = UserSession( user_id=async_test_superuser.id, @@ -878,14 +969,14 @@ class TestAdminListSessions: user_agent="Mozilla/5.0", last_used_at=now, expires_at=expires_at, - is_active=True + is_active=True, ) session.add_all([session1, session2]) await session.commit() response = await client.get( "/api/v1/admin/sessions?page=1&limit=10", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -905,13 +996,20 @@ class TestAdminListSessions: assert "is_active" in first_session @pytest.mark.asyncio - async def test_admin_list_sessions_filter_active(self, client, async_test_superuser, async_test_user, async_test_db, superuser_token): + async def test_admin_list_sessions_filter_active( + self, + client, + async_test_superuser, + async_test_user, + async_test_db, + superuser_token, + ): """Test filtering sessions by active status.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create active and inactive sessions async with AsyncTestingSessionLocal() as session: - now = datetime.now(timezone.utc) + now = datetime.now(UTC) expires_at = now + timedelta(days=7) active_session = UserSession( @@ -921,7 +1019,7 @@ class TestAdminListSessions: ip_address="192.168.1.100", last_used_at=now, expires_at=expires_at, - is_active=True + is_active=True, ) inactive_session = UserSession( user_id=async_test_user.id, @@ -930,7 +1028,7 @@ class TestAdminListSessions: ip_address="192.168.1.101", last_used_at=now, expires_at=expires_at, - is_active=False + is_active=False, ) session.add_all([active_session, inactive_session]) await session.commit() @@ -938,7 +1036,7 @@ class TestAdminListSessions: # Get only active sessions (default) response = await client.get( "/api/v1/admin/sessions?page=1&limit=100", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -949,13 +1047,15 @@ class TestAdminListSessions: assert sess["is_active"] is True @pytest.mark.asyncio - async def test_admin_list_sessions_pagination(self, client, async_test_superuser, async_test_db, superuser_token): + async def test_admin_list_sessions_pagination( + self, client, async_test_superuser, async_test_db, superuser_token + ): """Test pagination of sessions list.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create multiple sessions async with AsyncTestingSessionLocal() as session: - now = datetime.now(timezone.utc) + now = datetime.now(UTC) expires_at = now + timedelta(days=7) sessions = [] @@ -964,10 +1064,10 @@ class TestAdminListSessions: user_id=async_test_superuser.id, refresh_token_jti=f"jti-pagination-{i}", device_name=f"Device {i}", - ip_address=f"192.168.1.{100+i}", + ip_address=f"192.168.1.{100 + i}", last_used_at=now, expires_at=expires_at, - is_active=True + is_active=True, ) sessions.append(sess) session.add_all(sessions) @@ -976,7 +1076,7 @@ class TestAdminListSessions: # Get first page with limit 2 response = await client.get( "/api/v1/admin/sessions?page=1&limit=2", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -987,11 +1087,13 @@ class TestAdminListSessions: assert data["pagination"]["total"] >= 5 @pytest.mark.asyncio - async def test_admin_list_sessions_unauthorized(self, client, async_test_user, user_token): + async def test_admin_list_sessions_unauthorized( + self, client, async_test_user, user_token + ): """Test that non-admin users cannot access admin sessions endpoint.""" response = await client.get( "/api/v1/admin/sessions?page=1&limit=10", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/backend/tests/api/test_admin_error_handlers.py b/backend/tests/api/test_admin_error_handlers.py index 4c4bb0e..f639e12 100644 --- a/backend/tests/api/test_admin_error_handlers.py +++ b/backend/tests/api/test_admin_error_handlers.py @@ -3,12 +3,13 @@ Tests for admin route exception handlers, error paths, and success paths. Focus on code coverage of both error handling and normal operation branches. """ + +from unittest.mock import patch +from uuid import uuid4 + import pytest import pytest_asyncio -from unittest.mock import patch, AsyncMock from fastapi import status -from uuid import uuid4 -from app.models.user_organization import OrganizationRole @pytest_asyncio.fixture @@ -16,10 +17,7 @@ async def superuser_token(client, async_test_superuser): """Get access token for superuser.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "superuser@example.com", - "password": "SuperPassword123!" - } + json={"email": "superuser@example.com", "password": "SuperPassword123!"}, ) assert response.status_code == 200 return response.json()["access_token"] @@ -27,6 +25,7 @@ async def superuser_token(client, async_test_superuser): # ===== USER MANAGEMENT ERROR TESTS ===== + class TestAdminListUsersFilters: """Test admin list users with various filters.""" @@ -35,7 +34,7 @@ class TestAdminListUsersFilters: """Test listing users with is_superuser filter (covers line 96).""" response = await client.get( "/api/v1/admin/users?is_superuser=true", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -45,11 +44,14 @@ class TestAdminListUsersFilters: @pytest.mark.asyncio async def test_list_users_database_error_propagates(self, client, superuser_token): """Test that database errors propagate correctly (covers line 118-120).""" - with patch('app.api.routes.admin.user_crud.get_multi_with_total', side_effect=Exception("DB error")): + with patch( + "app.api.routes.admin.user_crud.get_multi_with_total", + side_effect=Exception("DB error"), + ): with pytest.raises(Exception): await client.get( "/api/v1/admin/users", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) @@ -57,7 +59,9 @@ class TestAdminCreateUserErrors: """Test admin create user error handling.""" @pytest.mark.asyncio - async def test_create_user_duplicate_email(self, client, async_test_user, superuser_token): + async def test_create_user_duplicate_email( + self, client, async_test_user, superuser_token + ): """Test creating user with duplicate email (covers line 145-150).""" response = await client.post( "/api/v1/admin/users", @@ -66,17 +70,22 @@ class TestAdminCreateUserErrors: "email": async_test_user.email, "password": "NewPassword123!", "first_name": "Duplicate", - "last_name": "User" - } + "last_name": "User", + }, ) # Should get error for duplicate email assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_create_user_unexpected_error_propagates(self, client, superuser_token): + async def test_create_user_unexpected_error_propagates( + self, client, superuser_token + ): """Test unexpected errors during user creation (covers line 151-153).""" - with patch('app.api.routes.admin.user_crud.create', side_effect=RuntimeError("Unexpected error")): + with patch( + "app.api.routes.admin.user_crud.create", + side_effect=RuntimeError("Unexpected error"), + ): with pytest.raises(RuntimeError): await client.post( "/api/v1/admin/users", @@ -85,8 +94,8 @@ class TestAdminCreateUserErrors: "email": "newerror@example.com", "password": "NewPassword123!", "first_name": "New", - "last_name": "User" - } + "last_name": "User", + }, ) @@ -99,7 +108,7 @@ class TestAdminGetUserErrors: fake_id = uuid4() response = await client.get( f"/api/v1/admin/users/{fake_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -115,20 +124,25 @@ class TestAdminUpdateUserErrors: response = await client.put( f"/api/v1/admin/users/{fake_id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"first_name": "Updated"} + json={"first_name": "Updated"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_update_user_unexpected_error(self, client, async_test_user, superuser_token): + async def test_update_user_unexpected_error( + self, client, async_test_user, superuser_token + ): """Test unexpected errors during user update (covers line 206-208).""" - with patch('app.api.routes.admin.user_crud.update', side_effect=RuntimeError("Update failed")): + with patch( + "app.api.routes.admin.user_crud.update", + side_effect=RuntimeError("Update failed"), + ): with pytest.raises(RuntimeError): await client.put( f"/api/v1/admin/users/{async_test_user.id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"first_name": "Updated"} + json={"first_name": "Updated"}, ) @@ -141,19 +155,24 @@ class TestAdminDeleteUserErrors: fake_id = uuid4() response = await client.delete( f"/api/v1/admin/users/{fake_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_delete_user_unexpected_error(self, client, async_test_user, superuser_token): + async def test_delete_user_unexpected_error( + self, client, async_test_user, superuser_token + ): """Test unexpected errors during user deletion (covers line 238-240).""" - with patch('app.api.routes.admin.user_crud.soft_delete', side_effect=Exception("Delete failed")): + with patch( + "app.api.routes.admin.user_crud.soft_delete", + side_effect=Exception("Delete failed"), + ): with pytest.raises(Exception): await client.delete( f"/api/v1/admin/users/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) @@ -166,19 +185,24 @@ class TestAdminActivateUserErrors: fake_id = uuid4() response = await client.post( f"/api/v1/admin/users/{fake_id}/activate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_activate_user_unexpected_error(self, client, async_test_user, superuser_token): + async def test_activate_user_unexpected_error( + self, client, async_test_user, superuser_token + ): """Test unexpected errors during user activation (covers line 282-284).""" - with patch('app.api.routes.admin.user_crud.update', side_effect=Exception("Activation failed")): + with patch( + "app.api.routes.admin.user_crud.update", + side_effect=Exception("Activation failed"), + ): with pytest.raises(Exception): await client.post( f"/api/v1/admin/users/{async_test_user.id}/activate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) @@ -191,45 +215,56 @@ class TestAdminDeactivateUserErrors: fake_id = uuid4() response = await client.post( f"/api/v1/admin/users/{fake_id}/deactivate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_deactivate_self_forbidden(self, client, async_test_superuser, superuser_token): + async def test_deactivate_self_forbidden( + self, client, async_test_superuser, superuser_token + ): """Test that admin cannot deactivate themselves (covers line 319-323).""" response = await client.post( f"/api/v1/admin/users/{async_test_superuser.id}/deactivate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio - async def test_deactivate_user_unexpected_error(self, client, async_test_user, superuser_token): + async def test_deactivate_user_unexpected_error( + self, client, async_test_user, superuser_token + ): """Test unexpected errors during user deactivation (covers line 326-328).""" - with patch('app.api.routes.admin.user_crud.update', side_effect=Exception("Deactivation failed")): + with patch( + "app.api.routes.admin.user_crud.update", + side_effect=Exception("Deactivation failed"), + ): with pytest.raises(Exception): await client.post( f"/api/v1/admin/users/{async_test_user.id}/deactivate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) # ===== ORGANIZATION MANAGEMENT ERROR TESTS ===== + class TestAdminListOrganizationsErrors: """Test admin list organizations error handling.""" @pytest.mark.asyncio async def test_list_organizations_database_error(self, client, superuser_token): """Test list organizations with database error (covers line 427-456).""" - with patch('app.api.routes.admin.organization_crud.get_multi_with_member_counts', side_effect=Exception("DB error")): + with patch( + "app.api.routes.admin.organization_crud.get_multi_with_member_counts", + side_effect=Exception("DB error"), + ): with pytest.raises(Exception): await client.get( "/api/v1/admin/organizations", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) @@ -237,17 +272,18 @@ class TestAdminCreateOrganizationErrors: """Test admin create organization error handling.""" @pytest.mark.asyncio - async def test_create_organization_duplicate_slug(self, client, async_test_db, superuser_token): + async def test_create_organization_duplicate_slug( + self, client, async_test_db, superuser_token + ): """Test creating organization with duplicate slug (covers line 480-483).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create an organization first async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization + org = Organization( - name="Existing Org", - slug="existing-org", - description="Test org" + name="Existing Org", slug="existing-org", description="Test org" ) session.add(org) await session.commit() @@ -259,8 +295,8 @@ class TestAdminCreateOrganizationErrors: json={ "name": "New Org", "slug": "existing-org", - "description": "Duplicate slug" - } + "description": "Duplicate slug", + }, ) # Should get error for duplicate slug @@ -269,16 +305,15 @@ class TestAdminCreateOrganizationErrors: @pytest.mark.asyncio async def test_create_organization_unexpected_error(self, client, superuser_token): """Test unexpected errors during organization creation (covers line 484-485).""" - with patch('app.api.routes.admin.organization_crud.create', side_effect=RuntimeError("Creation failed")): + with patch( + "app.api.routes.admin.organization_crud.create", + side_effect=RuntimeError("Creation failed"), + ): with pytest.raises(RuntimeError): await client.post( "/api/v1/admin/organizations", headers={"Authorization": f"Bearer {superuser_token}"}, - json={ - "name": "New Org", - "slug": "new-org", - "description": "Test" - } + json={"name": "New Org", "slug": "new-org", "description": "Test"}, ) @@ -291,7 +326,7 @@ class TestAdminGetOrganizationErrors: fake_id = uuid4() response = await client.get( f"/api/v1/admin/organizations/{fake_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -307,35 +342,39 @@ class TestAdminUpdateOrganizationErrors: response = await client.put( f"/api/v1/admin/organizations/{fake_id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"name": "Updated Org"} + json={"name": "Updated Org"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_update_organization_unexpected_error(self, client, async_test_db, superuser_token): + async def test_update_organization_unexpected_error( + self, client, async_test_db, superuser_token + ): """Test unexpected errors during organization update (covers line 573-575).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create an organization async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization + org = Organization( - name="Test Org", - slug="test-org-update-error", - description="Test" + name="Test Org", slug="test-org-update-error", description="Test" ) session.add(org) await session.commit() await session.refresh(org) org_id = org.id - with patch('app.api.routes.admin.organization_crud.update', side_effect=Exception("Update failed")): + with patch( + "app.api.routes.admin.organization_crud.update", + side_effect=Exception("Update failed"), + ): with pytest.raises(Exception): await client.put( f"/api/v1/admin/organizations/{org_id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"name": "Updated"} + json={"name": "Updated"}, ) @@ -348,34 +387,38 @@ class TestAdminDeleteOrganizationErrors: fake_id = uuid4() response = await client.delete( f"/api/v1/admin/organizations/{fake_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_delete_organization_unexpected_error(self, client, async_test_db, superuser_token): + async def test_delete_organization_unexpected_error( + self, client, async_test_db, superuser_token + ): """Test unexpected errors during organization deletion (covers line 611-613).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization + org = Organization( - name="Error Org", - slug="error-org-delete", - description="Test" + name="Error Org", slug="error-org-delete", description="Test" ) session.add(org) await session.commit() await session.refresh(org) org_id = org.id - with patch('app.api.routes.admin.organization_crud.remove', side_effect=Exception("Delete failed")): + with patch( + "app.api.routes.admin.organization_crud.remove", + side_effect=Exception("Delete failed"), + ): with pytest.raises(Exception): await client.delete( f"/api/v1/admin/organizations/{org_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) @@ -388,34 +431,38 @@ class TestAdminListOrganizationMembersErrors: fake_id = uuid4() response = await client.get( f"/api/v1/admin/organizations/{fake_id}/members", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_list_members_database_error(self, client, async_test_db, superuser_token): + async def test_list_members_database_error( + self, client, async_test_db, superuser_token + ): """Test database errors during member listing (covers line 660-662).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization + org = Organization( - name="Members Error Org", - slug="members-error-org", - description="Test" + name="Members Error Org", slug="members-error-org", description="Test" ) session.add(org) await session.commit() await session.refresh(org) org_id = org.id - with patch('app.api.routes.admin.organization_crud.get_organization_members', side_effect=Exception("DB error")): + with patch( + "app.api.routes.admin.organization_crud.get_organization_members", + side_effect=Exception("DB error"), + ): with pytest.raises(Exception): await client.get( f"/api/v1/admin/organizations/{org_id}/members", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) @@ -423,32 +470,32 @@ class TestAdminAddOrganizationMemberErrors: """Test admin add organization member error handling.""" @pytest.mark.asyncio - async def test_add_member_nonexistent_organization(self, client, async_test_user, superuser_token): + async def test_add_member_nonexistent_organization( + self, client, async_test_user, superuser_token + ): """Test adding member to non-existent organization (covers line 689-693).""" fake_id = uuid4() response = await client.post( f"/api/v1/admin/organizations/{fake_id}/members", headers={"Authorization": f"Bearer {superuser_token}"}, - json={ - "user_id": str(async_test_user.id), - "role": "member" - } + json={"user_id": str(async_test_user.id), "role": "member"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_add_nonexistent_user_to_organization(self, client, async_test_db, superuser_token): + async def test_add_nonexistent_user_to_organization( + self, client, async_test_db, superuser_token + ): """Test adding non-existent user to organization (covers line 696-700).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization + org = Organization( - name="Add Member Org", - slug="add-member-org", - description="Test" + name="Add Member Org", slug="add-member-org", description="Test" ) session.add(org) await session.commit() @@ -459,41 +506,39 @@ class TestAdminAddOrganizationMemberErrors: response = await client.post( f"/api/v1/admin/organizations/{org_id}/members", headers={"Authorization": f"Bearer {superuser_token}"}, - json={ - "user_id": str(fake_user_id), - "role": "member" - } + json={"user_id": str(fake_user_id), "role": "member"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_add_member_unexpected_error(self, client, async_test_db, async_test_user, superuser_token): + async def test_add_member_unexpected_error( + self, client, async_test_db, async_test_user, superuser_token + ): """Test unexpected errors during member addition (covers line 727-729).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization + org = Organization( - name="Error Add Org", - slug="error-add-org", - description="Test" + name="Error Add Org", slug="error-add-org", description="Test" ) session.add(org) await session.commit() await session.refresh(org) org_id = org.id - with patch('app.api.routes.admin.organization_crud.add_user', side_effect=Exception("Add failed")): + with patch( + "app.api.routes.admin.organization_crud.add_user", + side_effect=Exception("Add failed"), + ): with pytest.raises(Exception): await client.post( f"/api/v1/admin/organizations/{org_id}/members", headers={"Authorization": f"Bearer {superuser_token}"}, - json={ - "user_id": str(async_test_user.id), - "role": "member" - } + json={"user_id": str(async_test_user.id), "role": "member"}, ) @@ -501,30 +546,32 @@ class TestAdminRemoveOrganizationMemberErrors: """Test admin remove organization member error handling.""" @pytest.mark.asyncio - async def test_remove_member_nonexistent_organization(self, client, async_test_user, superuser_token): + async def test_remove_member_nonexistent_organization( + self, client, async_test_user, superuser_token + ): """Test removing member from non-existent organization (covers line 750-754).""" fake_id = uuid4() response = await client.delete( f"/api/v1/admin/organizations/{fake_id}/members/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_remove_member_unexpected_error(self, client, async_test_db, async_test_user, superuser_token): + async def test_remove_member_unexpected_error( + self, client, async_test_db, async_test_user, superuser_token + ): """Test unexpected errors during member removal (covers line 780-782).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization with member async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization - from app.models.user_organization import UserOrganization, OrganizationRole + from app.models.user_organization import OrganizationRole, UserOrganization org = Organization( - name="Remove Member Org", - slug="remove-member-org", - description="Test" + name="Remove Member Org", slug="remove-member-org", description="Test" ) session.add(org) await session.commit() @@ -533,22 +580,26 @@ class TestAdminRemoveOrganizationMemberErrors: member = UserOrganization( user_id=async_test_user.id, organization_id=org.id, - role=OrganizationRole.MEMBER + role=OrganizationRole.MEMBER, ) session.add(member) await session.commit() org_id = org.id - with patch('app.api.routes.admin.organization_crud.remove_user', side_effect=Exception("Remove failed")): + with patch( + "app.api.routes.admin.organization_crud.remove_user", + side_effect=Exception("Remove failed"), + ): with pytest.raises(Exception): await client.delete( f"/api/v1/admin/organizations/{org_id}/members/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) # ===== SUCCESS PATH TESTS ===== + class TestAdminListUsersSuccess: """Test admin list users success paths.""" @@ -557,7 +608,7 @@ class TestAdminListUsersSuccess: """Test listing users with pagination (covers lines 109-116).""" response = await client.get( "/api/v1/admin/users?page=1&limit=10", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -579,8 +630,8 @@ class TestAdminCreateUserSuccess: "email": f"newuser{uuid4().hex[:8]}@example.com", "password": "NewPassword123!", "first_name": "New", - "last_name": "User" - } + "last_name": "User", + }, ) assert response.status_code == status.HTTP_201_CREATED @@ -598,7 +649,7 @@ class TestAdminUpdateUserSuccess: response = await client.put( f"/api/v1/admin/users/{async_test_user.id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"first_name": "Updated"} + json={"first_name": "Updated"}, ) assert response.status_code == status.HTTP_200_OK @@ -612,18 +663,18 @@ class TestAdminDeleteUserSuccess: @pytest.mark.asyncio async def test_delete_user_success(self, client, async_test_db, superuser_token): """Test deleting user successfully (covers lines 226-246).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create a user to delete async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User user_to_delete = User( email=f"delete{uuid4().hex[:8]}@example.com", password_hash=get_password_hash("Password123!"), first_name="Delete", - last_name="Me" + last_name="Me", ) session.add(user_to_delete) await session.commit() @@ -632,7 +683,7 @@ class TestAdminDeleteUserSuccess: response = await client.delete( f"/api/v1/admin/users/{user_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -640,11 +691,13 @@ class TestAdminDeleteUserSuccess: assert data["success"] is True @pytest.mark.asyncio - async def test_delete_self_fails(self, client, async_test_superuser, superuser_token): + async def test_delete_self_fails( + self, client, async_test_superuser, superuser_token + ): """Test that admin cannot delete themselves (covers lines 233-238).""" response = await client.delete( f"/api/v1/admin/users/{async_test_superuser.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -656,19 +709,19 @@ class TestAdminActivateUserSuccess: @pytest.mark.asyncio async def test_activate_user_success(self, client, async_test_db, superuser_token): """Test activating user successfully (covers lines 270-282).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create inactive user async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User inactive_user = User( email=f"inactive{uuid4().hex[:8]}@example.com", password_hash=get_password_hash("Password123!"), first_name="Inactive", last_name="User", - is_active=False + is_active=False, ) session.add(inactive_user) await session.commit() @@ -677,7 +730,7 @@ class TestAdminActivateUserSuccess: response = await client.post( f"/api/v1/admin/users/{user_id}/activate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -689,11 +742,13 @@ class TestAdminDeactivateUserSuccess: """Test admin deactivate user success paths.""" @pytest.mark.asyncio - async def test_deactivate_user_success(self, client, async_test_user, superuser_token): + async def test_deactivate_user_success( + self, client, async_test_user, superuser_token + ): """Test deactivating user successfully (covers lines 306-326).""" response = await client.post( f"/api/v1/admin/users/{async_test_user.id}/deactivate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -701,11 +756,13 @@ class TestAdminDeactivateUserSuccess: assert data["success"] is True @pytest.mark.asyncio - async def test_deactivate_self_fails(self, client, async_test_superuser, superuser_token): + async def test_deactivate_self_fails( + self, client, async_test_superuser, superuser_token + ): """Test that admin cannot deactivate themselves (covers lines 313-318).""" response = await client.post( f"/api/v1/admin/users/{async_test_superuser.id}/deactivate", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -717,13 +774,13 @@ class TestAdminBulkUserActionSuccess: @pytest.mark.asyncio async def test_bulk_activate_success(self, client, async_test_db, superuser_token): """Test bulk activate users (covers lines 355-360, 375-392).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create inactive users user_ids = [] async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User for i in range(2): user = User( @@ -731,7 +788,7 @@ class TestAdminBulkUserActionSuccess: password_hash=get_password_hash("Password123!"), first_name="Bulk", last_name=f"User{i}", - is_active=False + is_active=False, ) session.add(user) await session.commit() @@ -741,10 +798,7 @@ class TestAdminBulkUserActionSuccess: response = await client.post( "/api/v1/admin/users/bulk-action", headers={"Authorization": f"Bearer {superuser_token}"}, - json={ - "action": "activate", - "user_ids": user_ids - } + json={"action": "activate", "user_ids": user_ids}, ) assert response.status_code == status.HTTP_200_OK @@ -752,15 +806,17 @@ class TestAdminBulkUserActionSuccess: assert data["affected_count"] >= 0 @pytest.mark.asyncio - async def test_bulk_deactivate_success(self, client, async_test_db, superuser_token): + async def test_bulk_deactivate_success( + self, client, async_test_db, superuser_token + ): """Test bulk deactivate users (covers lines 361-366).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create active users user_ids = [] async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User for i in range(2): user = User( @@ -768,7 +824,7 @@ class TestAdminBulkUserActionSuccess: password_hash=get_password_hash("Password123!"), first_name="Bulk", last_name=f"Deactivate{i}", - is_active=True + is_active=True, ) session.add(user) await session.commit() @@ -778,10 +834,7 @@ class TestAdminBulkUserActionSuccess: response = await client.post( "/api/v1/admin/users/bulk-action", headers={"Authorization": f"Bearer {superuser_token}"}, - json={ - "action": "deactivate", - "user_ids": user_ids - } + json={"action": "deactivate", "user_ids": user_ids}, ) assert response.status_code == status.HTTP_200_OK @@ -791,20 +844,20 @@ class TestAdminBulkUserActionSuccess: @pytest.mark.asyncio async def test_bulk_delete_success(self, client, async_test_db, superuser_token): """Test bulk delete users (covers lines 367-373).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create users to delete user_ids = [] async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User for i in range(2): user = User( email=f"bulkdel{i}{uuid4().hex[:8]}@example.com", password_hash=get_password_hash("Password123!"), first_name="Bulk", - last_name=f"Delete{i}" + last_name=f"Delete{i}", ) session.add(user) await session.commit() @@ -814,10 +867,7 @@ class TestAdminBulkUserActionSuccess: response = await client.post( "/api/v1/admin/users/bulk-action", headers={"Authorization": f"Bearer {superuser_token}"}, - json={ - "action": "delete", - "user_ids": user_ids - } + json={"action": "delete", "user_ids": user_ids}, ) assert response.status_code == status.HTTP_200_OK @@ -833,7 +883,7 @@ class TestAdminListOrganizationsSuccess: """Test listing organizations with pagination (covers lines 427-452).""" response = await client.get( "/api/v1/admin/organizations?page=1&limit=10", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -855,8 +905,8 @@ class TestAdminCreateOrganizationSuccess: json={ "name": "New Organization", "slug": unique_slug, - "description": "Test org" - } + "description": "Test org", + }, ) assert response.status_code == status.HTTP_201_CREATED @@ -869,18 +919,18 @@ class TestAdminGetOrganizationSuccess: """Test admin get organization success paths.""" @pytest.mark.asyncio - async def test_get_organization_success(self, client, async_test_db, superuser_token): + async def test_get_organization_success( + self, client, async_test_db, superuser_token + ): """Test getting organization successfully (covers lines 516-533).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization org = Organization( - name="Get Org", - slug=f"getorg{uuid4().hex[:8]}", - description="Test" + name="Get Org", slug=f"getorg{uuid4().hex[:8]}", description="Test" ) session.add(org) await session.commit() @@ -889,7 +939,7 @@ class TestAdminGetOrganizationSuccess: response = await client.get( f"/api/v1/admin/organizations/{org_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -901,9 +951,11 @@ class TestAdminUpdateOrganizationSuccess: """Test admin update organization success paths.""" @pytest.mark.asyncio - async def test_update_organization_success(self, client, async_test_db, superuser_token): + async def test_update_organization_success( + self, client, async_test_db, superuser_token + ): """Test updating organization successfully (covers lines 552-572).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization async with AsyncTestingSessionLocal() as session: @@ -912,7 +964,7 @@ class TestAdminUpdateOrganizationSuccess: org = Organization( name="Update Org", slug=f"updateorg{uuid4().hex[:8]}", - description="Test" + description="Test", ) session.add(org) await session.commit() @@ -922,7 +974,7 @@ class TestAdminUpdateOrganizationSuccess: response = await client.put( f"/api/v1/admin/organizations/{org_id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"name": "Updated Org Name"} + json={"name": "Updated Org Name"}, ) assert response.status_code == status.HTTP_200_OK @@ -934,9 +986,11 @@ class TestAdminDeleteOrganizationSuccess: """Test admin delete organization success paths.""" @pytest.mark.asyncio - async def test_delete_organization_success(self, client, async_test_db, superuser_token): + async def test_delete_organization_success( + self, client, async_test_db, superuser_token + ): """Test deleting organization successfully (covers lines 596-608).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization async with AsyncTestingSessionLocal() as session: @@ -945,7 +999,7 @@ class TestAdminDeleteOrganizationSuccess: org = Organization( name="Delete Org", slug=f"deleteorg{uuid4().hex[:8]}", - description="Test" + description="Test", ) session.add(org) await session.commit() @@ -954,7 +1008,7 @@ class TestAdminDeleteOrganizationSuccess: response = await client.delete( f"/api/v1/admin/organizations/{org_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -966,19 +1020,21 @@ class TestAdminListOrganizationMembersSuccess: """Test admin list organization members success paths.""" @pytest.mark.asyncio - async def test_list_organization_members_success(self, client, async_test_db, async_test_user, superuser_token): + async def test_list_organization_members_success( + self, client, async_test_db, async_test_user, superuser_token + ): """Test listing organization members successfully (covers lines 634-658).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization with member async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization - from app.models.user_organization import UserOrganization, OrganizationRole + from app.models.user_organization import OrganizationRole, UserOrganization org = Organization( name="Members Org", slug=f"membersorg{uuid4().hex[:8]}", - description="Test" + description="Test", ) session.add(org) await session.commit() @@ -987,7 +1043,7 @@ class TestAdminListOrganizationMembersSuccess: member = UserOrganization( user_id=async_test_user.id, organization_id=org.id, - role=OrganizationRole.MEMBER + role=OrganizationRole.MEMBER, ) session.add(member) await session.commit() @@ -995,7 +1051,7 @@ class TestAdminListOrganizationMembersSuccess: response = await client.get( f"/api/v1/admin/organizations/{org_id}/members", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -1008,9 +1064,11 @@ class TestAdminAddOrganizationMemberSuccess: """Test admin add organization member success paths.""" @pytest.mark.asyncio - async def test_add_member_success(self, client, async_test_db, async_test_user, superuser_token): + async def test_add_member_success( + self, client, async_test_db, async_test_user, superuser_token + ): """Test adding member to organization successfully (covers lines 689-717).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization async with AsyncTestingSessionLocal() as session: @@ -1019,7 +1077,7 @@ class TestAdminAddOrganizationMemberSuccess: org = Organization( name="Add Member Org", slug=f"addmemberorg{uuid4().hex[:8]}", - description="Test" + description="Test", ) session.add(org) await session.commit() @@ -1029,10 +1087,7 @@ class TestAdminAddOrganizationMemberSuccess: response = await client.post( f"/api/v1/admin/organizations/{org_id}/members", headers={"Authorization": f"Bearer {superuser_token}"}, - json={ - "user_id": str(async_test_user.id), - "role": "member" - } + json={"user_id": str(async_test_user.id), "role": "member"}, ) assert response.status_code == status.HTTP_200_OK @@ -1044,19 +1099,21 @@ class TestAdminRemoveOrganizationMemberSuccess: """Test admin remove organization member success paths.""" @pytest.mark.asyncio - async def test_remove_member_success(self, client, async_test_db, async_test_user, superuser_token): + async def test_remove_member_success( + self, client, async_test_db, async_test_user, superuser_token + ): """Test removing member from organization successfully (covers lines 750-780).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization with member async with AsyncTestingSessionLocal() as session: from app.models.organization import Organization - from app.models.user_organization import UserOrganization, OrganizationRole + from app.models.user_organization import OrganizationRole, UserOrganization org = Organization( name="Remove Member Success Org", slug=f"removemembersuccess{uuid4().hex[:8]}", - description="Test" + description="Test", ) session.add(org) await session.commit() @@ -1065,7 +1122,7 @@ class TestAdminRemoveOrganizationMemberSuccess: member = UserOrganization( user_id=async_test_user.id, organization_id=org.id, - role=OrganizationRole.MEMBER + role=OrganizationRole.MEMBER, ) session.add(member) await session.commit() @@ -1073,7 +1130,7 @@ class TestAdminRemoveOrganizationMemberSuccess: response = await client.delete( f"/api/v1/admin/organizations/{org_id}/members/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -1081,9 +1138,11 @@ class TestAdminRemoveOrganizationMemberSuccess: assert data["success"] is True @pytest.mark.asyncio - async def test_remove_nonmember_fails(self, client, async_test_db, async_test_user, superuser_token): + async def test_remove_nonmember_fails( + self, client, async_test_db, async_test_user, superuser_token + ): """Test removing non-member fails (covers lines 769-773).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create organization without member async with AsyncTestingSessionLocal() as session: @@ -1092,7 +1151,7 @@ class TestAdminRemoveOrganizationMemberSuccess: org = Organization( name="No Member Org", slug=f"nomemberorg{uuid4().hex[:8]}", - description="Test" + description="Test", ) session.add(org) await session.commit() @@ -1101,7 +1160,7 @@ class TestAdminRemoveOrganizationMemberSuccess: response = await client.delete( f"/api/v1/admin/organizations/{org_id}/members/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/backend/tests/api/test_auth.py b/backend/tests/api/test_auth.py index 4b19564..2f28383 100644 --- a/backend/tests/api/test_auth.py +++ b/backend/tests/api/test_auth.py @@ -2,6 +2,7 @@ """ Tests for authentication endpoints. """ + import pytest import pytest_asyncio from fastapi import status @@ -19,8 +20,8 @@ class TestRegisterEndpoint: "email": "newuser@example.com", "password": "NewPassword123!", "first_name": "New", - "last_name": "User" - } + "last_name": "User", + }, ) assert response.status_code == status.HTTP_201_CREATED @@ -36,8 +37,8 @@ class TestRegisterEndpoint: "email": async_test_user.email, "password": "TestPassword123!", "first_name": "Test", - "last_name": "User" - } + "last_name": "User", + }, ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -51,8 +52,8 @@ class TestRegisterEndpoint: "email": "test@example.com", "password": "weak", "first_name": "Test", - "last_name": "User" - } + "last_name": "User", + }, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @@ -66,10 +67,7 @@ class TestLoginEndpoint: """Test successful login.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) assert response.status_code == status.HTTP_200_OK @@ -82,10 +80,7 @@ class TestLoginEndpoint: """Test login with invalid password.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "WrongPassword123!" - } + json={"email": "testuser@example.com", "password": "WrongPassword123!"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -95,10 +90,7 @@ class TestLoginEndpoint: """Test login with non-existent user.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "nonexistent@example.com", - "password": "TestPassword123!" - } + json={"email": "nonexistent@example.com", "password": "TestPassword123!"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -106,27 +98,25 @@ class TestLoginEndpoint: @pytest.mark.asyncio async def test_login_inactive_user(self, client, async_test_db): """Test login with inactive user.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User + inactive_user = User( email="inactive@example.com", password_hash=get_password_hash("TestPassword123!"), first_name="Inactive", last_name="User", - is_active=False + is_active=False, ) session.add(inactive_user) await session.commit() response = await client.post( "/api/v1/auth/login", - json={ - "email": "inactive@example.com", - "password": "TestPassword123!" - } + json={"email": "inactive@example.com", "password": "TestPassword123!"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -140,10 +130,7 @@ class TestRefreshTokenEndpoint: """Get a refresh token for testing.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) return response.json()["refresh_token"] @@ -151,8 +138,7 @@ class TestRefreshTokenEndpoint: async def test_refresh_token_success(self, client, refresh_token): """Test successful token refresh.""" response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": refresh_token} + "/api/v1/auth/refresh", json={"refresh_token": refresh_token} ) assert response.status_code == status.HTTP_200_OK @@ -164,8 +150,7 @@ class TestRefreshTokenEndpoint: async def test_refresh_token_invalid(self, client): """Test refresh with invalid token.""" response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": "invalid.token.here"} + "/api/v1/auth/refresh", json={"refresh_token": "invalid.token.here"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -179,13 +164,13 @@ class TestLogoutEndpoint: """Get tokens for testing.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) data = response.json() - return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]} + return { + "access_token": data["access_token"], + "refresh_token": data["refresh_token"], + } @pytest.mark.asyncio async def test_logout_success(self, client, tokens): @@ -193,7 +178,7 @@ class TestLogoutEndpoint: response = await client.post( "/api/v1/auth/logout", headers={"Authorization": f"Bearer {tokens['access_token']}"}, - json={"refresh_token": tokens["refresh_token"]} + json={"refresh_token": tokens["refresh_token"]}, ) assert response.status_code == status.HTTP_200_OK @@ -202,8 +187,7 @@ class TestLogoutEndpoint: async def test_logout_without_auth(self, client): """Test logout without authentication.""" response = await client.post( - "/api/v1/auth/logout", - json={"refresh_token": "some.token"} + "/api/v1/auth/logout", json={"refresh_token": "some.token"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -215,8 +199,7 @@ class TestPasswordResetRequest: async def test_password_reset_request_success(self, client, async_test_user): """Test password reset request with existing user.""" response = await client.post( - "/api/v1/auth/password-reset/request", - json={"email": async_test_user.email} + "/api/v1/auth/password-reset/request", json={"email": async_test_user.email} ) assert response.status_code == status.HTTP_200_OK @@ -228,7 +211,7 @@ class TestPasswordResetRequest: """Test password reset request with non-existent email.""" response = await client.post( "/api/v1/auth/password-reset/request", - json={"email": "nonexistent@example.com"} + json={"email": "nonexistent@example.com"}, ) assert response.status_code == status.HTTP_200_OK @@ -244,10 +227,7 @@ class TestPasswordResetConfirm: """Test password reset with invalid token.""" response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": "invalid.token.here", - "new_password": "NewPassword123!" - } + json={"token": "invalid.token.here", "new_password": "NewPassword123!"}, ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -261,20 +241,20 @@ class TestLogoutAll: """Get tokens for testing.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) data = response.json() - return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]} + return { + "access_token": data["access_token"], + "refresh_token": data["refresh_token"], + } @pytest.mark.asyncio async def test_logout_all_success(self, client, tokens): """Test logout from all devices.""" response = await client.post( "/api/v1/auth/logout-all", - headers={"Authorization": f"Bearer {tokens['access_token']}"} + headers={"Authorization": f"Bearer {tokens['access_token']}"}, ) assert response.status_code == status.HTTP_200_OK @@ -298,10 +278,7 @@ class TestOAuthLogin: """Test successful OAuth login.""" response = await client.post( "/api/v1/auth/login/oauth", - data={ - "username": "testuser@example.com", - "password": "TestPassword123!" - } + data={"username": "testuser@example.com", "password": "TestPassword123!"}, ) assert response.status_code == status.HTTP_200_OK @@ -315,10 +292,7 @@ class TestOAuthLogin: """Test OAuth login with invalid credentials.""" response = await client.post( "/api/v1/auth/login/oauth", - data={ - "username": "testuser@example.com", - "password": "WrongPassword" - } + data={"username": "testuser@example.com", "password": "WrongPassword"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/backend/tests/api/test_auth_dependencies.py b/backend/tests/api/test_auth_dependencies.py index 3de6b05..5da4a55 100755 --- a/backend/tests/api/test_auth_dependencies.py +++ b/backend/tests/api/test_auth_dependencies.py @@ -1,15 +1,16 @@ # tests/api/dependencies/test_auth_dependencies.py -import pytest -import pytest_asyncio import uuid from unittest.mock import patch + +import pytest +import pytest_asyncio from fastapi import HTTPException from app.api.dependencies.auth import ( - get_current_user, get_current_active_user, get_current_superuser, - get_optional_current_user + get_current_user, + get_optional_current_user, ) from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash from app.models.user import User @@ -24,7 +25,7 @@ def mock_token(): @pytest_asyncio.fixture async def async_mock_user(async_test_db): """Async fixture to create and return a mock User instance.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: mock_user = User( id=uuid.uuid4(), @@ -47,12 +48,14 @@ class TestGetCurrentUser: """Tests for get_current_user dependency""" @pytest.mark.asyncio - async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token): + async def test_get_current_user_success( + self, async_test_db, async_mock_user, mock_token + ): """Test successfully getting the current user""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to return user_id that matches our mock_user - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = async_mock_user.id # Call the dependency @@ -65,12 +68,12 @@ class TestGetCurrentUser: @pytest.mark.asyncio async def test_get_current_user_nonexistent(self, async_test_db, mock_token): """Test when the token contains a user ID that doesn't exist""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to return a non-existent user ID nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111") - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = nonexistent_id # Should raise HTTPException with 404 status @@ -81,19 +84,24 @@ class TestGetCurrentUser: assert "User not found" in exc_info.value.detail @pytest.mark.asyncio - async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token): + async def test_get_current_user_inactive( + self, async_test_db, async_mock_user, mock_token + ): """Test when the user is inactive""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Get the user in this session and make it inactive from sqlalchemy import select - result = await session.execute(select(User).where(User.id == async_mock_user.id)) + + result = await session.execute( + select(User).where(User.id == async_mock_user.id) + ) user_in_session = result.scalar_one_or_none() user_in_session.is_active = False await session.commit() # Mock get_token_data - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = async_mock_user.id # Should raise HTTPException with 403 status @@ -106,10 +114,10 @@ class TestGetCurrentUser: @pytest.mark.asyncio async def test_get_current_user_expired_token(self, async_test_db, mock_token): """Test with an expired token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to raise TokenExpiredError - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.side_effect = TokenExpiredError("Token expired") # Should raise HTTPException with 401 status @@ -122,10 +130,10 @@ class TestGetCurrentUser: @pytest.mark.asyncio async def test_get_current_user_invalid_token(self, async_test_db, mock_token): """Test with an invalid token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to raise TokenInvalidError - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.side_effect = TokenInvalidError("Invalid token") # Should raise HTTPException with 401 status @@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser: """Tests for get_optional_current_user dependency""" @pytest.mark.asyncio - async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token): + async def test_get_optional_current_user_with_token( + self, async_test_db, async_mock_user, mock_token + ): """Test getting optional user with a valid token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = async_mock_user.id # Call the dependency @@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser: @pytest.mark.asyncio async def test_get_optional_current_user_no_token(self, async_test_db): """Test getting optional user with no token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Call the dependency with no token user = await get_optional_current_user(db=session, token=None) @@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser: assert user is None @pytest.mark.asyncio - async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token): + async def test_get_optional_current_user_invalid_token( + self, async_test_db, mock_token + ): """Test getting optional user with an invalid token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to raise TokenInvalidError - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.side_effect = TokenInvalidError("Invalid token") # Call the dependency @@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser: assert user is None @pytest.mark.asyncio - async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token): + async def test_get_optional_current_user_expired_token( + self, async_test_db, mock_token + ): """Test getting optional user with an expired token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock get_token_data to raise TokenExpiredError - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.side_effect = TokenExpiredError("Token expired") # Call the dependency @@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser: assert user is None @pytest.mark.asyncio - async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token): + async def test_get_optional_current_user_inactive( + self, async_test_db, async_mock_user, mock_token + ): """Test getting optional user when user is inactive""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Get the user in this session and make it inactive from sqlalchemy import select - result = await session.execute(select(User).where(User.id == async_mock_user.id)) + + result = await session.execute( + select(User).where(User.id == async_mock_user.id) + ) user_in_session = result.scalar_one_or_none() user_in_session.is_active = False await session.commit() # Mock get_token_data - with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: + with patch("app.api.dependencies.auth.get_token_data") as mock_get_data: mock_get_data.return_value.user_id = async_mock_user.id # Call the dependency diff --git a/backend/tests/api/test_auth_endpoints.py b/backend/tests/api/test_auth_endpoints.py index 833ff06..d70d5e7 100755 --- a/backend/tests/api/test_auth_endpoints.py +++ b/backend/tests/api/test_auth_endpoints.py @@ -2,21 +2,21 @@ """ Tests for authentication endpoints. """ + +from unittest.mock import patch + import pytest -import pytest_asyncio -from unittest.mock import patch, MagicMock from fastapi import status from sqlalchemy import select from app.models.user import User -from app.schemas.users import UserCreate # Disable rate limiting for tests @pytest.fixture(autouse=True) def disable_rate_limit(): """Disable rate limiting for all tests in this module.""" - with patch('app.api.routes.auth.limiter.enabled', False): + with patch("app.api.routes.auth.limiter.enabled", False): yield @@ -32,8 +32,8 @@ class TestRegisterEndpoint: "email": "newuser@example.com", "password": "SecurePassword123!", "first_name": "New", - "last_name": "User" - } + "last_name": "User", + }, ) assert response.status_code == status.HTTP_201_CREATED @@ -54,8 +54,8 @@ class TestRegisterEndpoint: "email": async_test_user.email, "password": "SecurePassword123!", "first_name": "Duplicate", - "last_name": "User" - } + "last_name": "User", + }, ) # Security: Returns 400 with generic message to prevent email enumeration @@ -73,8 +73,8 @@ class TestRegisterEndpoint: "email": "weakpass@example.com", "password": "weak", "first_name": "Weak", - "last_name": "Pass" - } + "last_name": "Pass", + }, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @@ -82,7 +82,7 @@ class TestRegisterEndpoint: @pytest.mark.asyncio async def test_register_unexpected_error(self, client): """Test registration with unexpected error.""" - with patch('app.services.auth_service.AuthService.create_user') as mock_create: + with patch("app.services.auth_service.AuthService.create_user") as mock_create: mock_create.side_effect = Exception("Unexpected error") response = await client.post( @@ -91,8 +91,8 @@ class TestRegisterEndpoint: "email": "error@example.com", "password": "SecurePassword123!", "first_name": "Error", - "last_name": "User" - } + "last_name": "User", + }, ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -106,10 +106,7 @@ class TestLoginEndpoint: """Test successful login.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": async_test_user.email, - "password": "TestPassword123!" - } + json={"email": async_test_user.email, "password": "TestPassword123!"}, ) assert response.status_code == status.HTTP_200_OK @@ -123,10 +120,7 @@ class TestLoginEndpoint: """Test login with wrong password.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": async_test_user.email, - "password": "WrongPassword123" - } + json={"email": async_test_user.email, "password": "WrongPassword123"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -136,10 +130,7 @@ class TestLoginEndpoint: """Test login with non-existent email.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "nonexistent@example.com", - "password": "Password123!" - } + json={"email": "nonexistent@example.com", "password": "Password123!"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -147,20 +138,19 @@ class TestLoginEndpoint: @pytest.mark.asyncio async def test_login_inactive_user(self, client, async_test_user, async_test_db): """Test login with inactive user.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Get the user in this session and make it inactive - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user_in_session = result.scalar_one_or_none() user_in_session.is_active = False await session.commit() response = await client.post( "/api/v1/auth/login", - json={ - "email": async_test_user.email, - "password": "TestPassword123!" - } + json={"email": async_test_user.email, "password": "TestPassword123!"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -168,15 +158,14 @@ class TestLoginEndpoint: @pytest.mark.asyncio async def test_login_unexpected_error(self, client, async_test_user): """Test login with unexpected error.""" - with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth: + with patch( + "app.services.auth_service.AuthService.authenticate_user" + ) as mock_auth: mock_auth.side_effect = Exception("Database error") response = await client.post( "/api/v1/auth/login", - json={ - "email": async_test_user.email, - "password": "TestPassword123!" - } + json={"email": async_test_user.email, "password": "TestPassword123!"}, ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -190,10 +179,7 @@ class TestOAuthLoginEndpoint: """Test successful OAuth login.""" response = await client.post( "/api/v1/auth/login/oauth", - data={ - "username": async_test_user.email, - "password": "TestPassword123!" - } + data={"username": async_test_user.email, "password": "TestPassword123!"}, ) assert response.status_code == status.HTTP_200_OK @@ -206,31 +192,29 @@ class TestOAuthLoginEndpoint: """Test OAuth login with wrong credentials.""" response = await client.post( "/api/v1/auth/login/oauth", - data={ - "username": async_test_user.email, - "password": "WrongPassword" - } + data={"username": async_test_user.email, "password": "WrongPassword"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.asyncio - async def test_oauth_login_inactive_user(self, client, async_test_user, async_test_db): + async def test_oauth_login_inactive_user( + self, client, async_test_user, async_test_db + ): """Test OAuth login with inactive user.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Get the user in this session and make it inactive - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user_in_session = result.scalar_one_or_none() user_in_session.is_active = False await session.commit() response = await client.post( "/api/v1/auth/login/oauth", - data={ - "username": async_test_user.email, - "password": "TestPassword123!" - } + data={"username": async_test_user.email, "password": "TestPassword123!"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -238,15 +222,17 @@ class TestOAuthLoginEndpoint: @pytest.mark.asyncio async def test_oauth_login_unexpected_error(self, client, async_test_user): """Test OAuth login with unexpected error.""" - with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth: + with patch( + "app.services.auth_service.AuthService.authenticate_user" + ) as mock_auth: mock_auth.side_effect = Exception("Unexpected error") response = await client.post( "/api/v1/auth/login/oauth", data={ "username": async_test_user.email, - "password": "TestPassword123!" - } + "password": "TestPassword123!", + }, ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -261,17 +247,13 @@ class TestRefreshTokenEndpoint: # First, login to get a refresh token login_response = await client.post( "/api/v1/auth/login", - json={ - "email": async_test_user.email, - "password": "TestPassword123!" - } + json={"email": async_test_user.email, "password": "TestPassword123!"}, ) refresh_token = login_response.json()["refresh_token"] # Now refresh the token response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": refresh_token} + "/api/v1/auth/refresh", json={"refresh_token": refresh_token} ) assert response.status_code == status.HTTP_200_OK @@ -284,12 +266,13 @@ class TestRefreshTokenEndpoint: """Test refresh with expired token.""" from app.core.auth import TokenExpiredError - with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh: + with patch( + "app.services.auth_service.AuthService.refresh_tokens" + ) as mock_refresh: mock_refresh.side_effect = TokenExpiredError("Token expired") response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": "some_token"} + "/api/v1/auth/refresh", json={"refresh_token": "some_token"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -298,8 +281,7 @@ class TestRefreshTokenEndpoint: async def test_refresh_token_invalid(self, client): """Test refresh with invalid token.""" response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": "invalid_token"} + "/api/v1/auth/refresh", json={"refresh_token": "invalid_token"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -310,19 +292,17 @@ class TestRefreshTokenEndpoint: # Get a valid refresh token first login_response = await client.post( "/api/v1/auth/login", - json={ - "email": async_test_user.email, - "password": "TestPassword123!" - } + json={"email": async_test_user.email, "password": "TestPassword123!"}, ) refresh_token = login_response.json()["refresh_token"] - with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh: + with patch( + "app.services.auth_service.AuthService.refresh_tokens" + ) as mock_refresh: mock_refresh.side_effect = Exception("Unexpected error") response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": refresh_token} + "/api/v1/auth/refresh", json={"refresh_token": refresh_token} ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR diff --git a/backend/tests/api/test_auth_error_handlers.py b/backend/tests/api/test_auth_error_handlers.py index 80a0f5f..ac95d37 100644 --- a/backend/tests/api/test_auth_error_handlers.py +++ b/backend/tests/api/test_auth_error_handlers.py @@ -2,8 +2,10 @@ """ Tests for auth route exception handlers and error paths. """ + +from unittest.mock import patch + import pytest -from unittest.mock import patch, AsyncMock from fastapi import status @@ -11,16 +13,18 @@ class TestLoginSessionCreationFailure: """Test login when session creation fails.""" @pytest.mark.asyncio - async def test_login_succeeds_despite_session_creation_failure(self, client, async_test_user): + async def test_login_succeeds_despite_session_creation_failure( + self, client, async_test_user + ): """Test that login succeeds even if session creation fails.""" # Mock session creation to fail - with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session creation failed")): + with patch( + "app.api.routes.auth.session_crud.create_session", + side_effect=Exception("Session creation failed"), + ): response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) # Login should still succeed, just without session record @@ -34,15 +38,20 @@ class TestOAuthLoginSessionCreationFailure: """Test OAuth login when session creation fails.""" @pytest.mark.asyncio - async def test_oauth_login_succeeds_despite_session_failure(self, client, async_test_user): + async def test_oauth_login_succeeds_despite_session_failure( + self, client, async_test_user + ): """Test OAuth login succeeds even if session creation fails.""" - with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session failed")): + with patch( + "app.api.routes.auth.session_crud.create_session", + side_effect=Exception("Session failed"), + ): response = await client.post( "/api/v1/auth/login/oauth", data={ "username": "testuser@example.com", - "password": "TestPassword123!" - } + "password": "TestPassword123!", + }, ) assert response.status_code == status.HTTP_200_OK @@ -54,23 +63,24 @@ class TestRefreshTokenSessionUpdateFailure: """Test refresh token when session update fails.""" @pytest.mark.asyncio - async def test_refresh_token_succeeds_despite_session_update_failure(self, client, async_test_user): + async def test_refresh_token_succeeds_despite_session_update_failure( + self, client, async_test_user + ): """Test that token refresh succeeds even if session update fails.""" # First login to get tokens response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) tokens = response.json() # Mock session update to fail - with patch('app.api.routes.auth.session_crud.update_refresh_token', side_effect=Exception("Update failed")): + with patch( + "app.api.routes.auth.session_crud.update_refresh_token", + side_effect=Exception("Update failed"), + ): response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": tokens["refresh_token"]} + "/api/v1/auth/refresh", json={"refresh_token": tokens["refresh_token"]} ) # Should still succeed - tokens are issued before update @@ -83,15 +93,14 @@ class TestLogoutWithExpiredToken: """Test logout with expired/invalid token.""" @pytest.mark.asyncio - async def test_logout_with_invalid_token_still_succeeds(self, client, async_test_user): + async def test_logout_with_invalid_token_still_succeeds( + self, client, async_test_user + ): """Test logout succeeds even with invalid refresh token.""" # Login first response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) access_token = response.json()["access_token"] @@ -99,7 +108,7 @@ class TestLogoutWithExpiredToken: response = await client.post( "/api/v1/auth/logout", headers={"Authorization": f"Bearer {access_token}"}, - json={"refresh_token": "invalid.token.here"} + json={"refresh_token": "invalid.token.here"}, ) # Should succeed (idempotent) @@ -116,19 +125,16 @@ class TestLogoutWithNonExistentSession: """Test logout succeeds even if session not found.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) tokens = response.json() # Mock session lookup to return None - with patch('app.api.routes.auth.session_crud.get_by_jti', return_value=None): + with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None): response = await client.post( "/api/v1/auth/logout", headers={"Authorization": f"Bearer {tokens['access_token']}"}, - json={"refresh_token": tokens["refresh_token"]} + json={"refresh_token": tokens["refresh_token"]}, ) # Should succeed (idempotent) @@ -139,23 +145,25 @@ class TestLogoutUnexpectedError: """Test logout with unexpected errors.""" @pytest.mark.asyncio - async def test_logout_with_unexpected_error_returns_success(self, client, async_test_user): + async def test_logout_with_unexpected_error_returns_success( + self, client, async_test_user + ): """Test logout returns success even on unexpected errors.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) tokens = response.json() # Mock to raise unexpected error - with patch('app.api.routes.auth.session_crud.get_by_jti', side_effect=Exception("Unexpected error")): + with patch( + "app.api.routes.auth.session_crud.get_by_jti", + side_effect=Exception("Unexpected error"), + ): response = await client.post( "/api/v1/auth/logout", headers={"Authorization": f"Bearer {tokens['access_token']}"}, - json={"refresh_token": tokens["refresh_token"]} + json={"refresh_token": tokens["refresh_token"]}, ) # Should still return success (don't expose errors) @@ -172,18 +180,18 @@ class TestLogoutAllUnexpectedError: """Test logout-all handles database errors.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) access_token = response.json()["access_token"] # Mock to raise database error - with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("DB error")): + with patch( + "app.api.routes.auth.session_crud.deactivate_all_user_sessions", + side_effect=Exception("DB error"), + ): response = await client.post( "/api/v1/auth/logout-all", - headers={"Authorization": f"Bearer {access_token}"} + headers={"Authorization": f"Bearer {access_token}"}, ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -193,7 +201,9 @@ class TestPasswordResetConfirmSessionInvalidation: """Test password reset invalidates sessions.""" @pytest.mark.asyncio - async def test_password_reset_continues_despite_session_invalidation_failure(self, client, async_test_user): + async def test_password_reset_continues_despite_session_invalidation_failure( + self, client, async_test_user + ): """Test password reset succeeds even if session invalidation fails.""" # Create a valid password reset token from app.utils.security import create_password_reset_token @@ -201,13 +211,13 @@ class TestPasswordResetConfirmSessionInvalidation: token = create_password_reset_token(async_test_user.email) # Mock session invalidation to fail - with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("Invalidation failed")): + with patch( + "app.api.routes.auth.session_crud.deactivate_all_user_sessions", + side_effect=Exception("Invalidation failed"), + ): response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": token, - "new_password": "NewPassword123!" - } + json={"token": token, "new_password": "NewPassword123!"}, ) # Should still succeed - password was reset diff --git a/backend/tests/api/test_auth_password_reset.py b/backend/tests/api/test_auth_password_reset.py index 6463b24..108dbe8 100755 --- a/backend/tests/api/test_auth_password_reset.py +++ b/backend/tests/api/test_auth_password_reset.py @@ -2,22 +2,22 @@ """ Tests for password reset endpoints. """ + +from unittest.mock import patch + import pytest -import pytest_asyncio -from unittest.mock import patch, AsyncMock, MagicMock from fastapi import status from sqlalchemy import select -from app.schemas.users import PasswordResetRequest, PasswordResetConfirm -from app.utils.security import create_password_reset_token from app.models.user import User +from app.utils.security import create_password_reset_token # Disable rate limiting for tests @pytest.fixture(autouse=True) def disable_rate_limit(): """Disable rate limiting for all tests in this module.""" - with patch('app.api.routes.auth.limiter.enabled', False): + with patch("app.api.routes.auth.limiter.enabled", False): yield @@ -27,12 +27,14 @@ class TestPasswordResetRequest: @pytest.mark.asyncio async def test_password_reset_request_valid_email(self, client, async_test_user): """Test password reset request with valid email.""" - with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + with patch( + "app.api.routes.auth.email_service.send_password_reset_email" + ) as mock_send: mock_send.return_value = True response = await client.post( "/api/v1/auth/password-reset/request", - json={"email": async_test_user.email} + json={"email": async_test_user.email}, ) assert response.status_code == status.HTTP_200_OK @@ -50,10 +52,12 @@ class TestPasswordResetRequest: @pytest.mark.asyncio async def test_password_reset_request_nonexistent_email(self, client): """Test password reset request with non-existent email.""" - with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + with patch( + "app.api.routes.auth.email_service.send_password_reset_email" + ) as mock_send: response = await client.post( "/api/v1/auth/password-reset/request", - json={"email": "nonexistent@example.com"} + json={"email": "nonexistent@example.com"}, ) # Should still return success to prevent email enumeration @@ -65,20 +69,26 @@ class TestPasswordResetRequest: mock_send.assert_not_called() @pytest.mark.asyncio - async def test_password_reset_request_inactive_user(self, client, async_test_db, async_test_user): + async def test_password_reset_request_inactive_user( + self, client, async_test_db, async_test_user + ): """Test password reset request with inactive user.""" # Deactivate user - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user_in_session = result.scalar_one_or_none() user_in_session.is_active = False await session.commit() - with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + with patch( + "app.api.routes.auth.email_service.send_password_reset_email" + ) as mock_send: response = await client.post( "/api/v1/auth/password-reset/request", - json={"email": async_test_user.email} + json={"email": async_test_user.email}, ) # Should still return success to prevent email enumeration @@ -93,8 +103,7 @@ class TestPasswordResetRequest: async def test_password_reset_request_invalid_email_format(self, client): """Test password reset request with invalid email format.""" response = await client.post( - "/api/v1/auth/password-reset/request", - json={"email": "not-an-email"} + "/api/v1/auth/password-reset/request", json={"email": "not-an-email"} ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @@ -102,22 +111,23 @@ class TestPasswordResetRequest: @pytest.mark.asyncio async def test_password_reset_request_missing_email(self, client): """Test password reset request without email.""" - response = await client.post( - "/api/v1/auth/password-reset/request", - json={} - ) + response = await client.post("/api/v1/auth/password-reset/request", json={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @pytest.mark.asyncio - async def test_password_reset_request_email_service_error(self, client, async_test_user): + async def test_password_reset_request_email_service_error( + self, client, async_test_user + ): """Test password reset when email service fails.""" - with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + with patch( + "app.api.routes.auth.email_service.send_password_reset_email" + ) as mock_send: mock_send.side_effect = Exception("SMTP Error") response = await client.post( "/api/v1/auth/password-reset/request", - json={"email": async_test_user.email} + json={"email": async_test_user.email}, ) # Should still return success even if email fails @@ -128,14 +138,16 @@ class TestPasswordResetRequest: @pytest.mark.asyncio async def test_password_reset_request_rate_limiting(self, client, async_test_user): """Test that password reset requests are rate limited.""" - with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + with patch( + "app.api.routes.auth.email_service.send_password_reset_email" + ) as mock_send: mock_send.return_value = True # Make multiple requests quickly (3/minute limit) for _ in range(3): response = await client.post( "/api/v1/auth/password-reset/request", - json={"email": async_test_user.email} + json={"email": async_test_user.email}, ) assert response.status_code == status.HTTP_200_OK @@ -144,7 +156,9 @@ class TestPasswordResetConfirm: """Tests for POST /auth/password-reset/confirm endpoint.""" @pytest.mark.asyncio - async def test_password_reset_confirm_valid_token(self, client, async_test_user, async_test_db): + async def test_password_reset_confirm_valid_token( + self, client, async_test_user, async_test_db + ): """Test password reset confirmation with valid token.""" # Generate valid token token = create_password_reset_token(async_test_user.email) @@ -152,10 +166,7 @@ class TestPasswordResetConfirm: response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": token, - "new_password": new_password - } + json={"token": token, "new_password": new_password}, ) assert response.status_code == status.HTTP_200_OK @@ -164,11 +175,14 @@ class TestPasswordResetConfirm: assert "successfully" in data["message"].lower() # Verify user can login with new password - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) updated_user = result.scalar_one_or_none() from app.core.auth import verify_password + assert verify_password(new_password, updated_user.password_hash) is True @pytest.mark.asyncio @@ -184,10 +198,7 @@ class TestPasswordResetConfirm: response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": token, - "new_password": "NewSecure123!" - } + json={"token": token, "new_password": "NewSecure123!"}, ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -202,10 +213,7 @@ class TestPasswordResetConfirm: """Test password reset confirmation with invalid token.""" response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": "invalid_token_xyz", - "new_password": "NewSecure123!" - } + json={"token": "invalid_token_xyz", "new_password": "NewSecure123!"}, ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -222,19 +230,18 @@ class TestPasswordResetConfirm: # Create valid token and tamper with it token = create_password_reset_token(async_test_user.email) - decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8") token_data = json.loads(decoded) token_data["payload"]["email"] = "hacker@example.com" # Re-encode tampered token - tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8') + tampered = base64.urlsafe_b64encode( + json.dumps(token_data).encode("utf-8") + ).decode("utf-8") response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": tampered, - "new_password": "NewSecure123!" - } + json={"token": tampered, "new_password": "NewSecure123!"}, ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -247,10 +254,7 @@ class TestPasswordResetConfirm: response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": token, - "new_password": "NewSecure123!" - } + json={"token": token, "new_password": "NewSecure123!"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -260,12 +264,16 @@ class TestPasswordResetConfirm: assert "not found" in error_msg @pytest.mark.asyncio - async def test_password_reset_confirm_inactive_user(self, client, async_test_user, async_test_db): + async def test_password_reset_confirm_inactive_user( + self, client, async_test_user, async_test_db + ): """Test password reset confirmation for inactive user.""" # Deactivate user - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user_in_session = result.scalar_one_or_none() user_in_session.is_active = False await session.commit() @@ -274,10 +282,7 @@ class TestPasswordResetConfirm: response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": token, - "new_password": "NewSecure123!" - } + json={"token": token, "new_password": "NewSecure123!"}, ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -301,10 +306,7 @@ class TestPasswordResetConfirm: for weak_password in weak_passwords: response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": token, - "new_password": weak_password - } + json={"token": token, "new_password": weak_password}, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @@ -315,15 +317,14 @@ class TestPasswordResetConfirm: # Missing token response = await client.post( "/api/v1/auth/password-reset/confirm", - json={"new_password": "NewSecure123!"} + json={"new_password": "NewSecure123!"}, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY # Missing password token = create_password_reset_token("test@example.com") response = await client.post( - "/api/v1/auth/password-reset/confirm", - json={"token": token} + "/api/v1/auth/password-reset/confirm", json={"token": token} ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @@ -333,15 +334,12 @@ class TestPasswordResetConfirm: token = create_password_reset_token(async_test_user.email) # Mock the database commit to raise an exception - with patch('app.api.routes.auth.user_crud.get_by_email') as mock_get: + with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get: mock_get.side_effect = Exception("Database error") response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": token, - "new_password": "NewSecure123!" - } + json={"token": token, "new_password": "NewSecure123!"}, ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -351,18 +349,22 @@ class TestPasswordResetConfirm: assert "error" in error_msg or "resetting" in error_msg @pytest.mark.asyncio - async def test_password_reset_full_flow(self, client, async_test_user, async_test_db): + async def test_password_reset_full_flow( + self, client, async_test_user, async_test_db + ): """Test complete password reset flow.""" original_password = async_test_user.password_hash new_password = "BrandNew123!" # Step 1: Request password reset - with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: + with patch( + "app.api.routes.auth.email_service.send_password_reset_email" + ) as mock_send: mock_send.return_value = True response = await client.post( "/api/v1/auth/password-reset/request", - json={"email": async_test_user.email} + json={"email": async_test_user.email}, ) assert response.status_code == status.HTTP_200_OK @@ -374,29 +376,24 @@ class TestPasswordResetConfirm: # Step 2: Confirm password reset response = await client.post( "/api/v1/auth/password-reset/confirm", - json={ - "token": reset_token, - "new_password": new_password - } + json={"token": reset_token, "new_password": new_password}, ) assert response.status_code == status.HTTP_200_OK # Step 3: Verify old password doesn't work - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) updated_user = result.scalar_one_or_none() - from app.core.auth import verify_password assert updated_user.password_hash != original_password # Step 4: Verify new password works response = await client.post( "/api/v1/auth/login", - json={ - "email": async_test_user.email, - "password": new_password - } + json={"email": async_test_user.email, "password": new_password}, ) assert response.status_code == status.HTTP_200_OK diff --git a/backend/tests/api/test_auth_security.py b/backend/tests/api/test_auth_security.py index 1c36b88..3ce8df6 100644 --- a/backend/tests/api/test_auth_security.py +++ b/backend/tests/api/test_auth_security.py @@ -8,11 +8,10 @@ Critical security tests covering: These tests prevent real-world attack scenarios. """ + import pytest from httpx import AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.auth import create_refresh_token from app.crud.session import session as session_crud from app.models.user import User @@ -30,10 +29,7 @@ class TestRevokedSessionSecurity: @pytest.mark.asyncio async def test_refresh_token_rejected_after_logout( - self, - client: AsyncClient, - async_test_db, - async_test_user: User + self, client: AsyncClient, async_test_db, async_test_user: User ): """ Test that refresh tokens are rejected after session is deactivated. @@ -45,10 +41,10 @@ class TestRevokedSessionSecurity: 4. Attacker tries to use stolen refresh token 5. System MUST reject it (session revoked) """ - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Step 1: Create a session and refresh token for the user - async with SessionLocal() as session: + async with SessionLocal(): # Login to get tokens response = await client.post( "/api/v1/auth/login", @@ -64,8 +60,7 @@ class TestRevokedSessionSecurity: # Step 2: Verify refresh token works before logout response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": refresh_token} + "/api/v1/auth/refresh", json={"refresh_token": refresh_token} ) assert response.status_code == 200, "Refresh should work before logout" @@ -73,14 +68,13 @@ class TestRevokedSessionSecurity: response = await client.post( "/api/v1/auth/logout", headers={"Authorization": f"Bearer {access_token}"}, - json={"refresh_token": refresh_token} + json={"refresh_token": refresh_token}, ) assert response.status_code == 200, "Logout should succeed" # Step 4: Attacker tries to use stolen refresh token response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": refresh_token} + "/api/v1/auth/refresh", json={"refresh_token": refresh_token} ) # Step 5: System MUST reject (covers lines 261-262) @@ -93,10 +87,7 @@ class TestRevokedSessionSecurity: @pytest.mark.asyncio async def test_refresh_token_rejected_for_deleted_session( - self, - client: AsyncClient, - async_test_db, - async_test_user: User + self, client: AsyncClient, async_test_db, async_test_user: User ): """ Test that tokens for deleted sessions are rejected. @@ -104,7 +95,7 @@ class TestRevokedSessionSecurity: Attack Scenario: Admin deletes a session from database, but attacker has the token. """ - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Step 1: Login to create a session response = await client.post( @@ -120,6 +111,7 @@ class TestRevokedSessionSecurity: # Step 2: Manually delete the session from database (simulating admin action) from app.core.auth import decode_token + token_data = decode_token(refresh_token, verify_type="refresh") jti = token_data.jti @@ -132,15 +124,17 @@ class TestRevokedSessionSecurity: # Step 3: Try to use the refresh token response = await client.post( - "/api/v1/auth/refresh", - json={"refresh_token": refresh_token} + "/api/v1/auth/refresh", json={"refresh_token": refresh_token} ) # Should reject (session doesn't exist) assert response.status_code == 401 data = response.json() if "errors" in data: - assert "revoked" in data["errors"][0]["message"].lower() or "session" in data["errors"][0]["message"].lower() + assert ( + "revoked" in data["errors"][0]["message"].lower() + or "session" in data["errors"][0]["message"].lower() + ) else: assert "revoked" in data.get("detail", "").lower() @@ -162,7 +156,7 @@ class TestSessionHijackingSecurity: client: AsyncClient, async_test_db, async_test_user: User, - async_test_superuser: User + async_test_superuser: User, ): """ Test that users cannot logout other users' sessions. @@ -173,7 +167,7 @@ class TestSessionHijackingSecurity: 3. User A tries to logout User B's session 4. System MUST reject (cross-user attack) """ - test_engine, SessionLocal = async_test_db + _test_engine, _SessionLocal = async_test_db # Step 1: User A logs in response = await client.post( @@ -202,8 +196,10 @@ class TestSessionHijackingSecurity: # Step 3: User A tries to logout User B's session using User B's refresh token response = await client.post( "/api/v1/auth/logout", - headers={"Authorization": f"Bearer {user_a_access}"}, # User A's access token - json={"refresh_token": user_b_refresh} # But User B's refresh token + headers={ + "Authorization": f"Bearer {user_a_access}" + }, # User A's access token + json={"refresh_token": user_b_refresh}, # But User B's refresh token ) # Step 4: System MUST reject (covers lines 509-513) @@ -217,9 +213,7 @@ class TestSessionHijackingSecurity: @pytest.mark.asyncio async def test_users_can_logout_their_own_sessions( - self, - client: AsyncClient, - async_test_user: User + self, client: AsyncClient, async_test_user: User ): """ Sanity check: Users CAN logout their own sessions. @@ -241,6 +235,8 @@ class TestSessionHijackingSecurity: response = await client.post( "/api/v1/auth/logout", headers={"Authorization": f"Bearer {tokens['access_token']}"}, - json={"refresh_token": tokens["refresh_token"]} + json={"refresh_token": tokens["refresh_token"]}, + ) + assert response.status_code == 200, ( + "Users should be able to logout their own sessions" ) - assert response.status_code == 200, "Users should be able to logout their own sessions" diff --git a/backend/tests/api/test_organizations.py b/backend/tests/api/test_organizations.py index a35975d..404a43e 100644 --- a/backend/tests/api/test_organizations.py +++ b/backend/tests/api/test_organizations.py @@ -5,16 +5,18 @@ Tests for organization routes (user endpoints). These test the routes in app/api/routes/organizations.py which allow users to view and manage organizations they belong to. """ + +from unittest.mock import patch +from uuid import uuid4 + import pytest import pytest_asyncio from fastapi import status -from uuid import uuid4 -from unittest.mock import patch, AsyncMock +from app.core.auth import get_password_hash from app.models.organization import Organization from app.models.user import User -from app.models.user_organization import UserOrganization, OrganizationRole -from app.core.auth import get_password_hash +from app.models.user_organization import OrganizationRole, UserOrganization @pytest_asyncio.fixture @@ -22,10 +24,7 @@ async def user_token(client, async_test_user): """Get access token for regular user.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) assert response.status_code == 200 return response.json()["access_token"] @@ -34,7 +33,7 @@ async def user_token(client, async_test_user): @pytest_asyncio.fixture async def second_user(async_test_db): """Create a second test user.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user = User( id=uuid4(), @@ -56,12 +55,12 @@ async def second_user(async_test_db): @pytest_asyncio.fixture async def test_org_with_user_member(async_test_db, async_test_user): """Create a test organization with async_test_user as a member.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization( name="Member Org", slug="member-org", - description="Test organization where user is a member" + description="Test organization where user is a member", ) session.add(org) await session.commit() @@ -72,7 +71,7 @@ async def test_org_with_user_member(async_test_db, async_test_user): user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(membership) await session.commit() @@ -83,12 +82,12 @@ async def test_org_with_user_member(async_test_db, async_test_user): @pytest_asyncio.fixture async def test_org_with_user_admin(async_test_db, async_test_user): """Create a test organization with async_test_user as an admin.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization( name="Admin Org", slug="admin-org", - description="Test organization where user is an admin" + description="Test organization where user is an admin", ) session.add(org) await session.commit() @@ -99,7 +98,7 @@ async def test_org_with_user_admin(async_test_db, async_test_user): user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.ADMIN, - is_active=True + is_active=True, ) session.add(membership) await session.commit() @@ -110,12 +109,12 @@ async def test_org_with_user_admin(async_test_db, async_test_user): @pytest_asyncio.fixture async def test_org_with_user_owner(async_test_db, async_test_user): """Create a test organization with async_test_user as owner.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization( name="Owner Org", slug="owner-org", - description="Test organization where user is owner" + description="Test organization where user is owner", ) session.add(org) await session.commit() @@ -126,7 +125,7 @@ async def test_org_with_user_owner(async_test_db, async_test_user): user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.OWNER, - is_active=True + is_active=True, ) session.add(membership) await session.commit() @@ -136,21 +135,18 @@ async def test_org_with_user_owner(async_test_db, async_test_user): # ===== GET /api/v1/organizations/me ===== + class TestGetMyOrganizations: """Tests for GET /api/v1/organizations/me endpoint.""" @pytest.mark.asyncio async def test_get_my_organizations_success( - self, - client, - user_token, - test_org_with_user_member, - test_org_with_user_admin + self, client, user_token, test_org_with_user_member, test_org_with_user_admin ): """Test successfully getting user's organizations (covers lines 54-79).""" response = await client.get( "/api/v1/organizations/me", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -167,21 +163,15 @@ class TestGetMyOrganizations: @pytest.mark.asyncio async def test_get_my_organizations_filter_active( - self, - client, - async_test_db, - async_test_user, - user_token + self, client, async_test_db, async_test_user, user_token ): """Test filtering organizations by active status.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create active org async with AsyncTestingSessionLocal() as session: active_org = Organization( - name="Active Org", - slug="active-org-filter", - is_active=True + name="Active Org", slug="active-org-filter", is_active=True ) session.add(active_org) await session.commit() @@ -192,14 +182,14 @@ class TestGetMyOrganizations: user_id=async_test_user.id, organization_id=active_org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(membership) await session.commit() response = await client.get( "/api/v1/organizations/me?is_active=true", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -209,7 +199,7 @@ class TestGetMyOrganizations: @pytest.mark.asyncio async def test_get_my_organizations_empty(self, client, async_test_db): """Test getting organizations when user has none.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create user with no org memberships async with AsyncTestingSessionLocal() as session: @@ -219,7 +209,7 @@ class TestGetMyOrganizations: password_hash=get_password_hash("TestPassword123!"), first_name="No", last_name="Org", - is_active=True + is_active=True, ) session.add(user) await session.commit() @@ -227,13 +217,12 @@ class TestGetMyOrganizations: # Login to get token login_response = await client.post( "/api/v1/auth/login", - json={"email": "noorg@example.com", "password": "TestPassword123!"} + json={"email": "noorg@example.com", "password": "TestPassword123!"}, ) token = login_response.json()["access_token"] response = await client.get( - "/api/v1/organizations/me", - headers={"Authorization": f"Bearer {token}"} + "/api/v1/organizations/me", headers={"Authorization": f"Bearer {token}"} ) assert response.status_code == status.HTTP_200_OK @@ -243,20 +232,18 @@ class TestGetMyOrganizations: # ===== GET /api/v1/organizations/{organization_id} ===== + class TestGetOrganization: """Tests for GET /api/v1/organizations/{organization_id} endpoint.""" @pytest.mark.asyncio async def test_get_organization_success( - self, - client, - user_token, - test_org_with_user_member + self, client, user_token, test_org_with_user_member ): """Test successfully getting organization details (covers lines 103-122).""" response = await client.get( f"/api/v1/organizations/{test_org_with_user_member.id}", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -272,7 +259,7 @@ class TestGetOrganization: fake_org_id = uuid4() response = await client.get( f"/api/v1/organizations/{fake_org_id}", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) # Permission dependency checks membership before endpoint logic @@ -283,20 +270,14 @@ class TestGetOrganization: @pytest.mark.asyncio async def test_get_organization_not_member( - self, - client, - async_test_db, - async_test_user + self, client, async_test_db, async_test_user ): """Test getting organization where user is not a member fails.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create org without adding user async with AsyncTestingSessionLocal() as session: - org = Organization( - name="Not Member Org", - slug="not-member-org" - ) + org = Organization(name="Not Member Org", slug="not-member-org") session.add(org) await session.commit() await session.refresh(org) @@ -305,13 +286,13 @@ class TestGetOrganization: # Login as user login_response = await client.post( "/api/v1/auth/login", - json={"email": "testuser@example.com", "password": "TestPassword123!"} + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) token = login_response.json()["access_token"] response = await client.get( f"/api/v1/organizations/{org_id}", - headers={"Authorization": f"Bearer {token}"} + headers={"Authorization": f"Bearer {token}"}, ) # Should fail permission check @@ -320,6 +301,7 @@ class TestGetOrganization: # ===== GET /api/v1/organizations/{organization_id}/members ===== + class TestGetOrganizationMembers: """Tests for GET /api/v1/organizations/{organization_id}/members endpoint.""" @@ -331,10 +313,10 @@ class TestGetOrganizationMembers: async_test_user, second_user, user_token, - test_org_with_user_member + test_org_with_user_member, ): """Test successfully getting organization members (covers lines 150-168).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Add second user to org async with AsyncTestingSessionLocal() as session: @@ -342,14 +324,14 @@ class TestGetOrganizationMembers: user_id=second_user.id, organization_id=test_org_with_user_member.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(membership) await session.commit() response = await client.get( f"/api/v1/organizations/{test_org_with_user_member.id}/members", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -360,15 +342,12 @@ class TestGetOrganizationMembers: @pytest.mark.asyncio async def test_get_organization_members_with_pagination( - self, - client, - user_token, - test_org_with_user_member + self, client, user_token, test_org_with_user_member ): """Test pagination parameters.""" response = await client.get( f"/api/v1/organizations/{test_org_with_user_member.id}/members?page=1&limit=10", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -385,10 +364,10 @@ class TestGetOrganizationMembers: async_test_user, second_user, user_token, - test_org_with_user_member + test_org_with_user_member, ): """Test filtering members by active status.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Add second user as inactive member async with AsyncTestingSessionLocal() as session: @@ -396,7 +375,7 @@ class TestGetOrganizationMembers: user_id=second_user.id, organization_id=test_org_with_user_member.id, role=OrganizationRole.MEMBER, - is_active=False + is_active=False, ) session.add(membership) await session.commit() @@ -404,7 +383,7 @@ class TestGetOrganizationMembers: # Filter for active only response = await client.get( f"/api/v1/organizations/{test_org_with_user_member.id}/members?is_active=true", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -416,31 +395,26 @@ class TestGetOrganizationMembers: # ===== PUT /api/v1/organizations/{organization_id} ===== + class TestUpdateOrganization: """Tests for PUT /api/v1/organizations/{organization_id} endpoint.""" @pytest.mark.asyncio async def test_update_organization_as_admin_success( - self, - client, - async_test_user, - test_org_with_user_admin + self, client, async_test_user, test_org_with_user_admin ): """Test successfully updating organization as admin (covers lines 193-215).""" # Login as admin user login_response = await client.post( "/api/v1/auth/login", - json={"email": "testuser@example.com", "password": "TestPassword123!"} + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) admin_token = login_response.json()["access_token"] response = await client.put( f"/api/v1/organizations/{test_org_with_user_admin.id}", - json={ - "name": "Updated Admin Org", - "description": "Updated description" - }, - headers={"Authorization": f"Bearer {admin_token}"} + json={"name": "Updated Admin Org", "description": "Updated description"}, + headers={"Authorization": f"Bearer {admin_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -450,23 +424,20 @@ class TestUpdateOrganization: @pytest.mark.asyncio async def test_update_organization_as_owner_success( - self, - client, - async_test_user, - test_org_with_user_owner + self, client, async_test_user, test_org_with_user_owner ): """Test successfully updating organization as owner.""" # Login as owner user login_response = await client.post( "/api/v1/auth/login", - json={"email": "testuser@example.com", "password": "TestPassword123!"} + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) owner_token = login_response.json()["access_token"] response = await client.put( f"/api/v1/organizations/{test_org_with_user_owner.id}", json={"name": "Updated Owner Org"}, - headers={"Authorization": f"Bearer {owner_token}"} + headers={"Authorization": f"Bearer {owner_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -475,16 +446,13 @@ class TestUpdateOrganization: @pytest.mark.asyncio async def test_update_organization_as_member_fails( - self, - client, - user_token, - test_org_with_user_member + self, client, user_token, test_org_with_user_member ): """Test updating organization as regular member fails.""" response = await client.put( f"/api/v1/organizations/{test_org_with_user_member.id}", json={"name": "Should Fail"}, - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) # Should fail permission check (need admin or owner) @@ -492,15 +460,13 @@ class TestUpdateOrganization: @pytest.mark.asyncio async def test_update_organization_not_found( - self, - client, - test_org_with_user_admin + self, client, test_org_with_user_admin ): """Test updating nonexistent organization returns 403 (permission check first).""" # Login as admin login_response = await client.post( "/api/v1/auth/login", - json={"email": "testuser@example.com", "password": "TestPassword123!"} + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) admin_token = login_response.json()["access_token"] @@ -508,7 +474,7 @@ class TestUpdateOrganization: response = await client.put( f"/api/v1/organizations/{fake_org_id}", json={"name": "Updated"}, - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}"}, ) # Permission dependency checks admin role before endpoint logic @@ -520,6 +486,7 @@ class TestUpdateOrganization: # ===== Authentication Tests ===== + class TestOrganizationAuthentication: """Test authentication requirements for organization endpoints.""" @@ -548,14 +515,14 @@ class TestOrganizationAuthentication: """Test unauthenticated access to update fails.""" fake_id = uuid4() response = await client.put( - f"/api/v1/organizations/{fake_id}", - json={"name": "Test"} + f"/api/v1/organizations/{fake_id}", json={"name": "Test"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED # ===== Exception Handler Tests (Database Error Scenarios) ===== + class TestOrganizationExceptionHandlers: """ Test exception handlers in organization endpoints. @@ -566,86 +533,74 @@ class TestOrganizationExceptionHandlers: @pytest.mark.asyncio async def test_get_my_organizations_database_error( - self, - client, - user_token, - test_org_with_user_member + self, client, user_token, test_org_with_user_member ): """Test generic exception handler in get_my_organizations (covers lines 81-83).""" with patch( "app.crud.organization.organization.get_user_organizations_with_details", - side_effect=Exception("Database connection lost") + side_effect=Exception("Database connection lost"), ): # The exception handler logs and re-raises, so we expect the exception # to propagate (which proves the handler executed) with pytest.raises(Exception, match="Database connection lost"): await client.get( "/api/v1/organizations/me", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) @pytest.mark.asyncio async def test_get_organization_database_error( - self, - client, - user_token, - test_org_with_user_member + self, client, user_token, test_org_with_user_member ): """Test generic exception handler in get_organization (covers lines 124-128).""" with patch( "app.crud.organization.organization.get", - side_effect=Exception("Database timeout") + side_effect=Exception("Database timeout"), ): with pytest.raises(Exception, match="Database timeout"): await client.get( f"/api/v1/organizations/{test_org_with_user_member.id}", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) @pytest.mark.asyncio async def test_get_organization_members_database_error( - self, - client, - user_token, - test_org_with_user_member + self, client, user_token, test_org_with_user_member ): """Test generic exception handler in get_organization_members (covers lines 170-172).""" with patch( "app.crud.organization.organization.get_organization_members", - side_effect=Exception("Connection pool exhausted") + side_effect=Exception("Connection pool exhausted"), ): with pytest.raises(Exception, match="Connection pool exhausted"): await client.get( f"/api/v1/organizations/{test_org_with_user_member.id}/members", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) @pytest.mark.asyncio async def test_update_organization_database_error( - self, - client, - async_test_user, - test_org_with_user_admin + self, client, async_test_user, test_org_with_user_admin ): """Test generic exception handler in update_organization (covers lines 217-221).""" # Login as admin user login_response = await client.post( "/api/v1/auth/login", - json={"email": "testuser@example.com", "password": "TestPassword123!"} + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) admin_token = login_response.json()["access_token"] with patch( "app.crud.organization.organization.get", - return_value=test_org_with_user_admin + return_value=test_org_with_user_admin, ): with patch( "app.crud.organization.organization.update", - side_effect=Exception("Write lock timeout") + side_effect=Exception("Write lock timeout"), ): with pytest.raises(Exception, match="Write lock timeout"): await client.put( f"/api/v1/organizations/{test_org_with_user_admin.id}", json={"name": "Should Fail"}, - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}"}, ) diff --git a/backend/tests/api/test_permissions.py b/backend/tests/api/test_permissions.py index 3802c51..66dacf0 100644 --- a/backend/tests/api/test_permissions.py +++ b/backend/tests/api/test_permissions.py @@ -5,15 +5,17 @@ Tests for permission dependencies - CRITICAL SECURITY PATHS. These tests ensure superusers can bypass organization checks correctly, and that regular users are properly blocked. """ + +from uuid import uuid4 + import pytest import pytest_asyncio from fastapi import status -from uuid import uuid4 +from app.core.auth import get_password_hash from app.models.organization import Organization from app.models.user import User -from app.models.user_organization import UserOrganization, OrganizationRole -from app.core.auth import get_password_hash +from app.models.user_organization import OrganizationRole, UserOrganization @pytest_asyncio.fixture @@ -21,10 +23,7 @@ async def superuser_token(client, async_test_superuser): """Get access token for superuser.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "superuser@example.com", - "password": "SuperPassword123!" - } + json={"email": "superuser@example.com", "password": "SuperPassword123!"}, ) assert response.status_code == 200 return response.json()["access_token"] @@ -35,10 +34,7 @@ async def regular_user_token(client, async_test_user): """Get access token for regular user.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) assert response.status_code == 200 return response.json()["access_token"] @@ -47,12 +43,12 @@ async def regular_user_token(client, async_test_user): @pytest_asyncio.fixture async def test_org_no_members(async_test_db): """Create a test organization with NO members.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization( name="No Members Org", slug="no-members-org", - description="Test org with no members" + description="Test org with no members", ) session.add(org) await session.commit() @@ -63,12 +59,12 @@ async def test_org_no_members(async_test_db): @pytest_asyncio.fixture async def test_org_with_member(async_test_db, async_test_user): """Create a test organization with async_test_user as member (not admin).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization( name="Member Only Org", slug="member-only-org", - description="Test org where user is just a member" + description="Test org where user is just a member", ) session.add(org) await session.commit() @@ -79,7 +75,7 @@ async def test_org_with_member(async_test_db, async_test_user): user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(membership) await session.commit() @@ -89,6 +85,7 @@ async def test_org_with_member(async_test_db, async_test_user): # ===== CRITICAL SECURITY TESTS: Superuser Bypass ===== + class TestSuperuserBypass: """ CRITICAL: Test that superusers can bypass organization checks. @@ -99,10 +96,7 @@ class TestSuperuserBypass: @pytest.mark.asyncio async def test_superuser_can_access_org_not_member_of( - self, - client, - superuser_token, - test_org_no_members + self, client, superuser_token, test_org_no_members ): """ CRITICAL: Superuser should bypass membership check (covers line 175). @@ -111,7 +105,7 @@ class TestSuperuserBypass: """ response = await client.get( f"/api/v1/organizations/{test_org_no_members.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) # Superuser should succeed even though they're not a member @@ -121,15 +115,12 @@ class TestSuperuserBypass: @pytest.mark.asyncio async def test_regular_user_cannot_access_org_not_member_of( - self, - client, - regular_user_token, - test_org_no_members + self, client, regular_user_token, test_org_no_members ): """Regular user should be blocked from org they're not a member of.""" response = await client.get( f"/api/v1/organizations/{test_org_no_members.id}", - headers={"Authorization": f"Bearer {regular_user_token}"} + headers={"Authorization": f"Bearer {regular_user_token}"}, ) # Regular user should fail permission check @@ -137,10 +128,7 @@ class TestSuperuserBypass: @pytest.mark.asyncio async def test_superuser_can_update_org_not_admin_of( - self, - client, - superuser_token, - test_org_no_members + self, client, superuser_token, test_org_no_members ): """ CRITICAL: Superuser should bypass admin check (covers line 99). @@ -150,7 +138,7 @@ class TestSuperuserBypass: response = await client.put( f"/api/v1/organizations/{test_org_no_members.id}", json={"name": "Updated by Superuser"}, - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) # Superuser should succeed in updating org @@ -160,16 +148,13 @@ class TestSuperuserBypass: @pytest.mark.asyncio async def test_regular_member_cannot_update_org( - self, - client, - regular_user_token, - test_org_with_member + self, client, regular_user_token, test_org_with_member ): """Regular member (not admin) should NOT be able to update org.""" response = await client.put( f"/api/v1/organizations/{test_org_with_member.id}", json={"name": "Should Fail"}, - headers={"Authorization": f"Bearer {regular_user_token}"} + headers={"Authorization": f"Bearer {regular_user_token}"}, ) # Member should fail - need admin or owner role @@ -177,15 +162,12 @@ class TestSuperuserBypass: @pytest.mark.asyncio async def test_superuser_can_list_org_members_not_member_of( - self, - client, - superuser_token, - test_org_no_members + self, client, superuser_token, test_org_no_members ): """CRITICAL: Superuser should bypass membership check to list members.""" response = await client.get( f"/api/v1/organizations/{test_org_no_members.id}/members", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) # Superuser should succeed @@ -197,13 +179,14 @@ class TestSuperuserBypass: # ===== Edge Cases and Security Tests ===== + class TestPermissionEdgeCases: """Test edge cases in permission system.""" @pytest.mark.asyncio async def test_inactive_user_blocked(self, client, async_test_db): """Test that inactive users are blocked.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create inactive user async with AsyncTestingSessionLocal() as session: @@ -213,7 +196,7 @@ class TestPermissionEdgeCases: password_hash=get_password_hash("TestPassword123!"), first_name="Inactive", last_name="User", - is_active=False # INACTIVE + is_active=False, # INACTIVE ) session.add(user) await session.commit() @@ -222,7 +205,7 @@ class TestPermissionEdgeCases: # But accessing protected endpoints should fail login_response = await client.post( "/api/v1/auth/login", - json={"email": "inactive@example.com", "password": "TestPassword123!"} + json={"email": "inactive@example.com", "password": "TestPassword123!"}, ) # Login might fail for inactive users depending on auth implementation @@ -231,18 +214,18 @@ class TestPermissionEdgeCases: # Try to access protected endpoint response = await client.get( - "/api/v1/users/me", - headers={"Authorization": f"Bearer {token}"} + "/api/v1/users/me", headers={"Authorization": f"Bearer {token}"} ) # Should be blocked - assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN] + assert response.status_code in [ + status.HTTP_401_UNAUTHORIZED, + status.HTTP_403_FORBIDDEN, + ] @pytest.mark.asyncio async def test_nonexistent_organization_returns_403_not_404( - self, - client, - regular_user_token + self, client, regular_user_token ): """ Test that accessing nonexistent org returns 403, not 404. @@ -254,7 +237,7 @@ class TestPermissionEdgeCases: fake_org_id = uuid4() response = await client.get( f"/api/v1/organizations/{fake_org_id}", - headers={"Authorization": f"Bearer {regular_user_token}"} + headers={"Authorization": f"Bearer {regular_user_token}"}, ) # Should get 403 (not a member), not 404 (doesn't exist) @@ -264,18 +247,16 @@ class TestPermissionEdgeCases: # ===== Admin Role Tests ===== + class TestAdminRolePermissions: """Test admin role can perform admin actions.""" @pytest_asyncio.fixture async def test_org_with_admin(self, async_test_db, async_test_user): """Create org where user is ADMIN.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - org = Organization( - name="Admin Org", - slug="admin-org" - ) + org = Organization(name="Admin Org", slug="admin-org") session.add(org) await session.commit() await session.refresh(org) @@ -284,7 +265,7 @@ class TestAdminRolePermissions: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.ADMIN, - is_active=True + is_active=True, ) session.add(membership) await session.commit() @@ -293,16 +274,13 @@ class TestAdminRolePermissions: @pytest.mark.asyncio async def test_admin_can_update_org( - self, - client, - regular_user_token, - test_org_with_admin + self, client, regular_user_token, test_org_with_admin ): """Admin should be able to update organization.""" response = await client.put( f"/api/v1/organizations/{test_org_with_admin.id}", json={"name": "Updated by Admin"}, - headers={"Authorization": f"Bearer {regular_user_token}"} + headers={"Authorization": f"Bearer {regular_user_token}"}, ) assert response.status_code == status.HTTP_200_OK diff --git a/backend/tests/api/test_permissions_security.py b/backend/tests/api/test_permissions_security.py index 34cddbf..46ac706 100644 --- a/backend/tests/api/test_permissions_security.py +++ b/backend/tests/api/test_permissions_security.py @@ -7,13 +7,13 @@ Critical security tests covering: These tests prevent unauthorized access and privilege escalation. """ + import pytest from httpx import AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession -from app.models.user import User -from app.models.organization import Organization from app.crud.user import user as user_crud +from app.models.organization import Organization +from app.models.user import User class TestInactiveUserBlocking: @@ -29,11 +29,7 @@ class TestInactiveUserBlocking: @pytest.mark.asyncio async def test_inactive_user_cannot_access_protected_endpoints( - self, - client: AsyncClient, - async_test_db, - async_test_user: User, - user_token: str + self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str ): """ Test that inactive users are blocked from protected endpoints. @@ -44,12 +40,11 @@ class TestInactiveUserBlocking: 3. User tries to access protected endpoint with valid token 4. System MUST reject (account inactive) """ - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Step 1: Verify user can access endpoint while active response = await client.get( - "/api/v1/users/me", - headers={"Authorization": f"Bearer {user_token}"} + "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"} ) assert response.status_code == 200, "Active user should have access" @@ -61,8 +56,7 @@ class TestInactiveUserBlocking: # Step 3: User tries to access endpoint with same token response = await client.get( - "/api/v1/users/me", - headers={"Authorization": f"Bearer {user_token}"} + "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"} ) # Step 4: System MUST reject (covers lines 52-57) @@ -75,18 +69,14 @@ class TestInactiveUserBlocking: @pytest.mark.asyncio async def test_inactive_user_blocked_from_organization_endpoints( - self, - client: AsyncClient, - async_test_db, - async_test_user: User, - user_token: str + self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str ): """ Test that inactive users can't access organization endpoints. Ensures the inactive check applies to ALL protected endpoints. """ - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Deactivate user async with SessionLocal() as session: @@ -97,7 +87,7 @@ class TestInactiveUserBlocking: # Try to list organizations response = await client.get( "/api/v1/organizations/me", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) # Must be blocked @@ -122,7 +112,7 @@ class TestSuperuserPrivilegeEscalation: client: AsyncClient, async_test_db, async_test_superuser: User, - superuser_token: str + superuser_token: str, ): """ Test that superusers automatically get OWNER role in organizations. @@ -131,14 +121,11 @@ class TestSuperuserPrivilegeEscalation: Superusers can manage any organization without being explicitly added. This is for platform administration. """ - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Step 1: Create an organization (owned by someone else) async with SessionLocal() as session: - org = Organization( - name="Test Organization", - slug="test-org" - ) + org = Organization(name="Test Organization", slug="test-org") session.add(org) await session.commit() await session.refresh(org) @@ -148,7 +135,7 @@ class TestSuperuserPrivilegeEscalation: # (They're not a member, but should auto-get OWNER role) response = await client.get( f"/api/v1/organizations/{org_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) # Step 3: Should have access (covers lines 154-157) @@ -161,21 +148,18 @@ class TestSuperuserPrivilegeEscalation: client: AsyncClient, async_test_db, async_test_superuser: User, - superuser_token: str + superuser_token: str, ): """ Test that superusers have full management access to all organizations. Ensures the OWNER role privilege escalation works end-to-end. """ - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create an organization async with SessionLocal() as session: - org = Organization( - name="Test Organization", - slug="test-org" - ) + org = Organization(name="Test Organization", slug="test-org") session.add(org) await session.commit() await session.refresh(org) @@ -185,34 +169,29 @@ class TestSuperuserPrivilegeEscalation: response = await client.put( f"/api/v1/organizations/{org_id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"name": "Updated Name"} + json={"name": "Updated Name"}, ) # Should succeed (superuser has OWNER privileges) - assert response.status_code in [200, 404], "Superuser should be able to manage any org" + assert response.status_code in [200, 404], ( + "Superuser should be able to manage any org" + ) # Note: Might be 404 if org endpoints require membership, but the role check passes @pytest.mark.asyncio async def test_regular_user_does_not_get_owner_role( - self, - client: AsyncClient, - async_test_db, - async_test_user: User, - user_token: str + self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str ): """ Sanity check: Regular users don't get automatic OWNER role. Ensures the superuser check is working correctly (line 154). """ - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create an organization async with SessionLocal() as session: - org = Organization( - name="Test Organization", - slug="test-org" - ) + org = Organization(name="Test Organization", slug="test-org") session.add(org) await session.commit() await session.refresh(org) @@ -221,8 +200,10 @@ class TestSuperuserPrivilegeEscalation: # Regular user tries to access it (not a member) response = await client.get( f"/api/v1/organizations/{org_id}", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) # Should be denied (not a member, not a superuser) - assert response.status_code in [403, 404], "Regular user shouldn't access non-member org" + assert response.status_code in [403, 404], ( + "Regular user shouldn't access non-member org" + ) diff --git a/backend/tests/api/test_security_headers.py b/backend/tests/api/test_security_headers.py index 49898b6..b44b853 100755 --- a/backend/tests/api/test_security_headers.py +++ b/backend/tests/api/test_security_headers.py @@ -1,7 +1,8 @@ # tests/api/test_security_headers.py +from unittest.mock import patch + import pytest from fastapi.testclient import TestClient -from unittest.mock import patch from app.main import app @@ -11,8 +12,10 @@ def client(): """Create a FastAPI test client for the main app (module-scoped for speed).""" # Mock get_db to avoid database connection issues with patch("app.core.database.get_db") as mock_get_db: + async def mock_session_generator(): - from unittest.mock import MagicMock, AsyncMock + from unittest.mock import AsyncMock, MagicMock + mock_session = MagicMock() mock_session.execute = AsyncMock(return_value=None) mock_session.close = AsyncMock(return_value=None) @@ -77,8 +80,10 @@ class TestSecurityHeaders: """Test that HSTS header is set in production (covers line 95)""" with patch("app.core.config.settings.ENVIRONMENT", "production"): with patch("app.core.database.get_db") as mock_get_db: + async def mock_session_generator(): - from unittest.mock import MagicMock, AsyncMock + from unittest.mock import AsyncMock, MagicMock + mock_session = MagicMock() mock_session.execute = AsyncMock(return_value=None) mock_session.close = AsyncMock(return_value=None) @@ -88,20 +93,26 @@ class TestSecurityHeaders: # Need to reimport app to pick up the new settings from importlib import reload + import app.main + reload(app.main) test_client = TestClient(app.main.app) response = test_client.get("/health") assert "Strict-Transport-Security" in response.headers - assert "max-age=31536000" in response.headers["Strict-Transport-Security"] + assert ( + "max-age=31536000" in response.headers["Strict-Transport-Security"] + ) def test_csp_strict_mode(self): """Test CSP strict mode (covers line 121)""" with patch("app.core.config.settings.CSP_MODE", "strict"): with patch("app.core.database.get_db") as mock_get_db: + async def mock_session_generator(): - from unittest.mock import MagicMock, AsyncMock + from unittest.mock import AsyncMock, MagicMock + mock_session = MagicMock() mock_session.execute = AsyncMock(return_value=None) mock_session.close = AsyncMock(return_value=None) @@ -110,7 +121,9 @@ class TestSecurityHeaders: mock_get_db.side_effect = lambda: mock_session_generator() from importlib import reload + import app.main + reload(app.main) test_client = TestClient(app.main.app) @@ -136,8 +149,10 @@ class TestRootEndpoint: def test_root_endpoint(self): """Test root endpoint returns HTML (covers line 174)""" with patch("app.core.database.get_db") as mock_get_db: + async def mock_session_generator(): - from unittest.mock import MagicMock, AsyncMock + from unittest.mock import AsyncMock, MagicMock + mock_session = MagicMock() mock_session.execute = AsyncMock(return_value=None) mock_session.close = AsyncMock(return_value=None) diff --git a/backend/tests/api/test_sessions.py b/backend/tests/api/test_sessions.py index 826c9f5..acb0cbf 100644 --- a/backend/tests/api/test_sessions.py +++ b/backend/tests/api/test_sessions.py @@ -2,23 +2,23 @@ """ Comprehensive tests for session management API endpoints. """ + +from datetime import UTC, datetime, timedelta +from unittest.mock import patch +from uuid import uuid4 + import pytest import pytest_asyncio -from datetime import datetime, timedelta, timezone -from uuid import uuid4 -from unittest.mock import patch - from fastapi import status from app.models.user_session import UserSession -from app.schemas.users import UserCreate # Disable rate limiting for tests @pytest.fixture(autouse=True) def disable_rate_limit(): """Disable rate limiting for all tests in this module.""" - with patch('app.api.routes.sessions.limiter.enabled', False): + with patch("app.api.routes.sessions.limiter.enabled", False): yield @@ -27,10 +27,7 @@ async def user_token(client, async_test_user): """Create and return an access token for async_test_user.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) assert response.status_code == 200 return response.json()["access_token"] @@ -39,7 +36,7 @@ async def user_token(client, async_test_user): @pytest_asyncio.fixture async def async_test_user2(async_test_db): """Create a second test user.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: from app.crud.user import user as user_crud @@ -49,7 +46,7 @@ async def async_test_user2(async_test_db): email="testuser2@example.com", password="TestPassword123!", first_name="Test", - last_name="User2" + last_name="User2", ) user = await user_crud.create(session, obj_in=user_data) await session.commit() @@ -61,9 +58,11 @@ class TestListMySessions: """Tests for GET /api/v1/sessions/me endpoint.""" @pytest.mark.asyncio - async def test_list_my_sessions_success(self, client, async_test_user, async_test_db, user_token): + async def test_list_my_sessions_success( + self, client, async_test_user, async_test_db, user_token + ): """Test successfully listing user's active sessions.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create some sessions for the user async with SessionLocal() as session: @@ -75,8 +74,8 @@ class TestListMySessions: ip_address="192.168.1.100", user_agent="Mozilla/5.0 (iPhone)", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) # Active session 2 s2 = UserSession( @@ -86,8 +85,8 @@ class TestListMySessions: ip_address="192.168.1.101", user_agent="Mozilla/5.0 (Macintosh)", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) - timedelta(hours=1) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC) - timedelta(hours=1), ) # Inactive session (should not appear) s3 = UserSession( @@ -97,16 +96,15 @@ class TestListMySessions: ip_address="192.168.1.102", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) - timedelta(days=1) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC) - timedelta(days=1), ) session.add_all([s1, s2, s3]) await session.commit() # Make request response = await client.get( - "/api/v1/sessions/me", - headers={"Authorization": f"Bearer {user_token}"} + "/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"} ) assert response.status_code == status.HTTP_200_OK @@ -128,11 +126,12 @@ class TestListMySessions: assert data["sessions"][0]["is_current"] is True @pytest.mark.asyncio - async def test_list_my_sessions_with_login_session(self, client, async_test_user, user_token): + async def test_list_my_sessions_with_login_session( + self, client, async_test_user, user_token + ): """Test listing sessions shows the login session.""" response = await client.get( - "/api/v1/sessions/me", - headers={"Authorization": f"Bearer {user_token}"} + "/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"} ) assert response.status_code == status.HTTP_200_OK @@ -155,9 +154,11 @@ class TestRevokeSession: """Tests for DELETE /api/v1/sessions/{session_id} endpoint.""" @pytest.mark.asyncio - async def test_revoke_session_success(self, client, async_test_user, async_test_db, user_token): + async def test_revoke_session_success( + self, client, async_test_user, async_test_db, user_token + ): """Test successfully revoking a session.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create a session to revoke async with SessionLocal() as session: @@ -168,8 +169,8 @@ class TestRevokeSession: ip_address="192.168.1.103", user_agent="Mozilla/5.0 (iPad)", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(user_session) await session.commit() @@ -179,7 +180,7 @@ class TestRevokeSession: # Revoke the session response = await client.delete( f"/api/v1/sessions/{session_id}", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -191,6 +192,7 @@ class TestRevokeSession: # Verify session is deactivated async with SessionLocal() as session: from app.crud.session import session as session_crud + revoked_session = await session_crud.get(session, id=str(session_id)) assert revoked_session.is_active is False @@ -200,7 +202,7 @@ class TestRevokeSession: fake_id = uuid4() response = await client.delete( f"/api/v1/sessions/{fake_id}", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -222,7 +224,7 @@ class TestRevokeSession: self, client, async_test_user, async_test_user2, async_test_db, user_token ): """Test that users cannot revoke other users' sessions.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create a session for user2 async with SessionLocal() as session: @@ -233,8 +235,8 @@ class TestRevokeSession: ip_address="192.168.1.200", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(other_user_session) await session.commit() @@ -244,7 +246,7 @@ class TestRevokeSession: # Try to revoke it as user1 response = await client.delete( f"/api/v1/sessions/{session_id}", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -263,7 +265,7 @@ class TestCleanupExpiredSessions: self, client, async_test_user, async_test_db, user_token ): """Test successfully cleaning up expired sessions.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create expired and active sessions using CRUD to avoid greenlet issues from app.crud.session import session as session_crud @@ -277,8 +279,8 @@ class TestCleanupExpiredSessions: device_name="Expired 1", ip_address="192.168.1.201", user_agent="Mozilla/5.0", - expires_at=datetime.now(timezone.utc) - timedelta(days=1), - last_used_at=datetime.now(timezone.utc) - timedelta(days=2) + expires_at=datetime.now(UTC) - timedelta(days=1), + last_used_at=datetime.now(UTC) - timedelta(days=2), ) e1 = await session_crud.create_session(db, obj_in=e1_data) e1.is_active = False @@ -291,8 +293,8 @@ class TestCleanupExpiredSessions: device_name="Expired 2", ip_address="192.168.1.202", user_agent="Mozilla/5.0", - expires_at=datetime.now(timezone.utc) - timedelta(hours=1), - last_used_at=datetime.now(timezone.utc) - timedelta(hours=2) + expires_at=datetime.now(UTC) - timedelta(hours=1), + last_used_at=datetime.now(UTC) - timedelta(hours=2), ) e2 = await session_crud.create_session(db, obj_in=e2_data) e2.is_active = False @@ -305,8 +307,8 @@ class TestCleanupExpiredSessions: device_name="Active", ip_address="192.168.1.203", user_agent="Mozilla/5.0", - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) await session_crud.create_session(db, obj_in=a1_data) await db.commit() @@ -314,7 +316,7 @@ class TestCleanupExpiredSessions: # Cleanup expired sessions response = await client.delete( "/api/v1/sessions/me/expired", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -329,7 +331,7 @@ class TestCleanupExpiredSessions: self, client, async_test_user, async_test_db, user_token ): """Test cleanup when no sessions are expired.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create only active sessions using CRUD from app.crud.session import session as session_crud @@ -342,15 +344,15 @@ class TestCleanupExpiredSessions: device_name="Active Device", ip_address="192.168.1.210", user_agent="Mozilla/5.0", - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) await session_crud.create_session(db, obj_in=a1_data) await db.commit() response = await client.delete( "/api/v1/sessions/me/expired", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -369,13 +371,16 @@ class TestCleanupExpiredSessions: # Additional tests for better coverage + class TestSessionsAdditionalCases: """Additional tests to improve sessions endpoint coverage.""" @pytest.mark.asyncio - async def test_list_sessions_pagination(self, client, async_test_user, async_test_db, user_token): + async def test_list_sessions_pagination( + self, client, async_test_user, async_test_db, user_token + ): """Test listing sessions with pagination.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create multiple sessions async with SessionLocal() as session: @@ -389,15 +394,15 @@ class TestSessionsAdditionalCases: device_name=f"Device {i}", ip_address=f"192.168.1.{i}", user_agent="Mozilla/5.0", - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) await session_crud.create_session(session, obj_in=session_data) await session.commit() response = await client.get( "/api/v1/sessions/me?page=1&limit=3", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -410,16 +415,21 @@ class TestSessionsAdditionalCases: """Test revoking session with invalid UUID.""" response = await client.delete( "/api/v1/sessions/not-a-uuid", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) # Should return 422 for invalid UUID format - assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_404_NOT_FOUND] + assert response.status_code in [ + status.HTTP_422_UNPROCESSABLE_ENTITY, + status.HTTP_404_NOT_FOUND, + ] @pytest.mark.asyncio - async def test_cleanup_expired_sessions_with_mixed_states(self, client, async_test_user, async_test_db, user_token): + async def test_cleanup_expired_sessions_with_mixed_states( + self, client, async_test_user, async_test_db, user_token + ): """Test cleanup with mix of active/inactive and expired/not-expired sessions.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db from app.crud.session import session as session_crud from app.schemas.sessions import SessionCreate @@ -432,8 +442,8 @@ class TestSessionsAdditionalCases: device_name="Expired Inactive", ip_address="192.168.1.100", user_agent="Mozilla/5.0", - expires_at=datetime.now(timezone.utc) - timedelta(days=1), - last_used_at=datetime.now(timezone.utc) - timedelta(days=2) + expires_at=datetime.now(UTC) - timedelta(days=1), + last_used_at=datetime.now(UTC) - timedelta(days=2), ) e1 = await session_crud.create_session(db, obj_in=e1_data) e1.is_active = False @@ -446,8 +456,8 @@ class TestSessionsAdditionalCases: device_name="Expired Active", ip_address="192.168.1.101", user_agent="Mozilla/5.0", - expires_at=datetime.now(timezone.utc) - timedelta(hours=1), - last_used_at=datetime.now(timezone.utc) - timedelta(hours=2) + expires_at=datetime.now(UTC) - timedelta(hours=1), + last_used_at=datetime.now(UTC) - timedelta(hours=2), ) await session_crud.create_session(db, obj_in=e2_data) @@ -455,7 +465,7 @@ class TestSessionsAdditionalCases: response = await client.delete( "/api/v1/sessions/me/expired", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -476,10 +486,12 @@ class TestSessionExceptionHandlers: from unittest.mock import patch # Patch decode_token to raise an exception - with patch('app.api.routes.sessions.decode_token', side_effect=Exception("Token decode error")): + with patch( + "app.api.routes.sessions.decode_token", + side_effect=Exception("Token decode error"), + ): response = await client.get( - "/api/v1/sessions/me", - headers={"Authorization": f"Bearer {user_token}"} + "/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"} ) # Should still succeed (exception is caught and ignored in try/except at line 77) @@ -489,12 +501,16 @@ class TestSessionExceptionHandlers: async def test_list_sessions_database_error(self, client, user_token): """Test list_sessions handles database errors (covers lines 104-106).""" from unittest.mock import patch + from app.crud import session as session_module - with patch.object(session_module.session, 'get_user_sessions', side_effect=Exception("Database error")): + with patch.object( + session_module.session, + "get_user_sessions", + side_effect=Exception("Database error"), + ): response = await client.get( - "/api/v1/sessions/me", - headers={"Authorization": f"Bearer {user_token}"} + "/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"} ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -503,18 +519,21 @@ class TestSessionExceptionHandlers: assert data["errors"][0]["message"] == "Failed to retrieve sessions" @pytest.mark.asyncio - async def test_revoke_session_database_error(self, client, user_token, async_test_db, async_test_user): + async def test_revoke_session_database_error( + self, client, user_token, async_test_db, async_test_user + ): """Test revoke_session handles database errors (covers lines 181-183).""" + from datetime import datetime, timedelta from unittest.mock import patch from uuid import uuid4 + from app.crud import session as session_module # First create a session to revoke from app.crud.session import session as session_crud from app.schemas.sessions import SessionCreate - from datetime import datetime, timedelta, timezone - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as db: session_in = SessionCreate( @@ -523,17 +542,21 @@ class TestSessionExceptionHandlers: device_name="Test Device", ip_address="192.168.1.1", user_agent="Mozilla/5.0", - last_used_at=datetime.now(timezone.utc), - expires_at=datetime.now(timezone.utc) + timedelta(days=60) + last_used_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(days=60), ) user_session = await session_crud.create_session(db, obj_in=session_in) session_id = user_session.id # Mock the deactivate method to raise an exception - with patch.object(session_module.session, 'deactivate', side_effect=Exception("Database connection lost")): + with patch.object( + session_module.session, + "deactivate", + side_effect=Exception("Database connection lost"), + ): response = await client.delete( f"/api/v1/sessions/{session_id}", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -544,12 +567,17 @@ class TestSessionExceptionHandlers: async def test_cleanup_expired_sessions_database_error(self, client, user_token): """Test cleanup_expired_sessions handles database errors (covers lines 233-236).""" from unittest.mock import patch + from app.crud import session as session_module - with patch.object(session_module.session, 'cleanup_expired_for_user', side_effect=Exception("Cleanup failed")): + with patch.object( + session_module.session, + "cleanup_expired_for_user", + side_effect=Exception("Cleanup failed"), + ): response = await client.delete( "/api/v1/sessions/me/expired", - headers={"Authorization": f"Bearer {user_token}"} + headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR diff --git a/backend/tests/api/test_user_routes.py b/backend/tests/api/test_user_routes.py index cb22c12..cafd04a 100755 --- a/backend/tests/api/test_user_routes.py +++ b/backend/tests/api/test_user_routes.py @@ -3,32 +3,29 @@ Comprehensive tests for user management endpoints. These tests focus on finding potential bugs, not just coverage. """ -import pytest -import pytest_asyncio -from unittest.mock import patch -from fastapi import status -import uuid -from sqlalchemy import select +import uuid +from unittest.mock import patch + +import pytest +from fastapi import status + from app.models.user import User -from app.models.user import User -from app.schemas.users import UserUpdate # Disable rate limiting for tests @pytest.fixture(autouse=True) def disable_rate_limit(): """Disable rate limiting for all tests in this module.""" - with patch('app.api.routes.users.limiter.enabled', False): - with patch('app.api.routes.auth.limiter.enabled', False): + with patch("app.api.routes.users.limiter.enabled", False): + with patch("app.api.routes.auth.limiter.enabled", False): yield async def get_auth_headers(client, email, password): """Helper to get authentication headers.""" response = await client.post( - "/api/v1/auth/login", - json={"email": email, "password": password} + "/api/v1/auth/login", json={"email": email, "password": password} ) token = response.json()["access_token"] return {"Authorization": f"Bearer {token}"} @@ -40,7 +37,9 @@ class TestListUsers: @pytest.mark.asyncio async def test_list_users_as_superuser(self, client, async_test_superuser): """Test listing users as superuser.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) response = await client.get("/api/v1/users", headers=headers) @@ -53,16 +52,20 @@ class TestListUsers: @pytest.mark.asyncio async def test_list_users_as_regular_user(self, client, async_test_user): """Test that regular users cannot list users.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.get("/api/v1/users", headers=headers) assert response.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio - async def test_list_users_pagination(self, client, async_test_superuser, async_test_db): + async def test_list_users_pagination( + self, client, async_test_superuser, async_test_db + ): """Test pagination works correctly.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create multiple users async with AsyncTestingSessionLocal() as session: @@ -72,12 +75,14 @@ class TestListUsers: password_hash="hash", first_name=f"PagUser{i}", is_active=True, - is_superuser=False + is_superuser=False, ) session.add(user) await session.commit() - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) # Get first page response = await client.get("/api/v1/users?page=1&limit=5", headers=headers) @@ -88,9 +93,11 @@ class TestListUsers: assert data["pagination"]["total"] >= 15 @pytest.mark.asyncio - async def test_list_users_filter_active(self, client, async_test_superuser, async_test_db): + async def test_list_users_filter_active( + self, client, async_test_superuser, async_test_db + ): """Test filtering by active status.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create active and inactive users async with AsyncTestingSessionLocal() as session: @@ -99,19 +106,21 @@ class TestListUsers: password_hash="hash", first_name="Active", is_active=True, - is_superuser=False + is_superuser=False, ) inactive_user = User( email="inactivefilter@example.com", password_hash="hash", first_name="Inactive", is_active=False, - is_superuser=False + is_superuser=False, ) session.add_all([active_user, inactive_user]) await session.commit() - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) # Filter for active users response = await client.get("/api/v1/users?is_active=true", headers=headers) @@ -130,9 +139,13 @@ class TestListUsers: @pytest.mark.asyncio async def test_list_users_sort_by_email(self, client, async_test_superuser): """Test sorting users by email.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) - response = await client.get("/api/v1/users?sort_by=email&sort_order=asc", headers=headers) + response = await client.get( + "/api/v1/users?sort_by=email&sort_order=asc", headers=headers + ) assert response.status_code == status.HTTP_200_OK data = response.json() emails = [u["email"] for u in data["data"]] @@ -154,7 +167,9 @@ class TestGetCurrentUserProfile: @pytest.mark.asyncio async def test_get_own_profile(self, client, async_test_user): """Test getting own profile.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.get("/api/v1/users/me", headers=headers) @@ -176,12 +191,14 @@ class TestUpdateCurrentUser: @pytest.mark.asyncio async def test_update_own_profile(self, client, async_test_user): """Test updating own profile.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.patch( "/api/v1/users/me", headers=headers, - json={"first_name": "Updated", "last_name": "Name"} + json={"first_name": "Updated", "last_name": "Name"}, ) assert response.status_code == status.HTTP_200_OK @@ -192,12 +209,12 @@ class TestUpdateCurrentUser: @pytest.mark.asyncio async def test_update_profile_phone_number(self, client, async_test_user, test_db): """Test updating phone number with validation.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.patch( - "/api/v1/users/me", - headers=headers, - json={"phone_number": "+19876543210"} + "/api/v1/users/me", headers=headers, json={"phone_number": "+19876543210"} ) assert response.status_code == status.HTTP_200_OK @@ -207,12 +224,12 @@ class TestUpdateCurrentUser: @pytest.mark.asyncio async def test_update_profile_invalid_phone(self, client, async_test_user): """Test that invalid phone numbers are rejected.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.patch( - "/api/v1/users/me", - headers=headers, - json={"phone_number": "invalid"} + "/api/v1/users/me", headers=headers, json={"phone_number": "invalid"} ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @@ -220,14 +237,16 @@ class TestUpdateCurrentUser: @pytest.mark.asyncio async def test_cannot_elevate_to_superuser(self, client, async_test_user): """Test that users cannot make themselves superuser.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) # Note: is_superuser is now in UserUpdate schema with explicit validation # This tests that Pydantic rejects the attempt at the schema level response = await client.patch( "/api/v1/users/me", headers=headers, - json={"first_name": "Test", "is_superuser": True} + json={"first_name": "Test", "is_superuser": True}, ) # Pydantic validation should reject this at the schema level @@ -242,10 +261,7 @@ class TestUpdateCurrentUser: @pytest.mark.asyncio async def test_update_profile_no_auth(self, client): """Test that unauthenticated requests are rejected.""" - response = await client.patch( - "/api/v1/users/me", - json={"first_name": "Hacker"} - ) + response = await client.patch("/api/v1/users/me", json={"first_name": "Hacker"}) assert response.status_code == status.HTTP_401_UNAUTHORIZED # Note: Removed test_update_profile_unexpected_error - see comment above @@ -257,16 +273,22 @@ class TestGetUserById: @pytest.mark.asyncio async def test_get_own_profile_by_id(self, client, async_test_user): """Test getting own profile by ID.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) - response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers) + response = await client.get( + f"/api/v1/users/{async_test_user.id}", headers=headers + ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["email"] == async_test_user.email @pytest.mark.asyncio - async def test_get_other_user_as_regular_user(self, client, async_test_user, test_db): + async def test_get_other_user_as_regular_user( + self, client, async_test_user, test_db + ): """Test that regular users cannot view other profiles.""" # Create another user other_user = User( @@ -274,24 +296,32 @@ class TestGetUserById: password_hash="hash", first_name="Other", is_active=True, - is_superuser=False + is_superuser=False, ) test_db.add(other_user) test_db.commit() test_db.refresh(other_user) - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.get(f"/api/v1/users/{other_user.id}", headers=headers) assert response.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio - async def test_get_other_user_as_superuser(self, client, async_test_superuser, async_test_user): + async def test_get_other_user_as_superuser( + self, client, async_test_superuser, async_test_user + ): """Test that superusers can view other profiles.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) - response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers) + response = await client.get( + f"/api/v1/users/{async_test_user.id}", headers=headers + ) assert response.status_code == status.HTTP_200_OK data = response.json() @@ -300,7 +330,9 @@ class TestGetUserById: @pytest.mark.asyncio async def test_get_nonexistent_user(self, client, async_test_superuser): """Test getting non-existent user.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) fake_id = uuid.uuid4() response = await client.get(f"/api/v1/users/{fake_id}", headers=headers) @@ -310,7 +342,9 @@ class TestGetUserById: @pytest.mark.asyncio async def test_get_user_invalid_uuid(self, client, async_test_superuser): """Test getting user with invalid UUID format.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) response = await client.get("/api/v1/users/not-a-uuid", headers=headers) @@ -323,12 +357,14 @@ class TestUpdateUserById: @pytest.mark.asyncio async def test_update_own_profile_by_id(self, client, async_test_user, test_db): """Test updating own profile by ID.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.patch( f"/api/v1/users/{async_test_user.id}", headers=headers, - json={"first_name": "SelfUpdated"} + json={"first_name": "SelfUpdated"}, ) assert response.status_code == status.HTTP_200_OK @@ -336,7 +372,9 @@ class TestUpdateUserById: assert data["first_name"] == "SelfUpdated" @pytest.mark.asyncio - async def test_update_other_user_as_regular_user(self, client, async_test_user, test_db): + async def test_update_other_user_as_regular_user( + self, client, async_test_user, test_db + ): """Test that regular users cannot update other profiles.""" # Create another user other_user = User( @@ -344,18 +382,20 @@ class TestUpdateUserById: password_hash="hash", first_name="Other", is_active=True, - is_superuser=False + is_superuser=False, ) test_db.add(other_user) test_db.commit() test_db.refresh(other_user) - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.patch( f"/api/v1/users/{other_user.id}", headers=headers, - json={"first_name": "Hacked"} + json={"first_name": "Hacked"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -365,14 +405,18 @@ class TestUpdateUserById: assert other_user.first_name == "Other" @pytest.mark.asyncio - async def test_update_other_user_as_superuser(self, client, async_test_superuser, async_test_user, test_db): + async def test_update_other_user_as_superuser( + self, client, async_test_superuser, async_test_user, test_db + ): """Test that superusers can update other profiles.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) response = await client.patch( f"/api/v1/users/{async_test_user.id}", headers=headers, - json={"first_name": "AdminUpdated"} + json={"first_name": "AdminUpdated"}, ) assert response.status_code == status.HTTP_200_OK @@ -380,16 +424,20 @@ class TestUpdateUserById: assert data["first_name"] == "AdminUpdated" @pytest.mark.asyncio - async def test_regular_user_cannot_modify_superuser_status(self, client, async_test_user): + async def test_regular_user_cannot_modify_superuser_status( + self, client, async_test_user + ): """Test that regular users cannot change superuser status even if they try.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) # is_superuser not in UserUpdate schema, so it gets ignored by Pydantic # Just verify the user stays the same response = await client.patch( f"/api/v1/users/{async_test_user.id}", headers=headers, - json={"first_name": "Test"} + json={"first_name": "Test"}, ) assert response.status_code == status.HTTP_200_OK @@ -397,14 +445,18 @@ class TestUpdateUserById: assert data["is_superuser"] is False @pytest.mark.asyncio - async def test_superuser_can_update_users(self, client, async_test_superuser, async_test_user, test_db): + async def test_superuser_can_update_users( + self, client, async_test_superuser, async_test_user, test_db + ): """Test that superusers can update other users.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) response = await client.patch( f"/api/v1/users/{async_test_user.id}", headers=headers, - json={"first_name": "AdminChanged", "is_active": False} + json={"first_name": "AdminChanged", "is_active": False}, ) assert response.status_code == status.HTTP_200_OK @@ -415,13 +467,13 @@ class TestUpdateUserById: @pytest.mark.asyncio async def test_update_nonexistent_user(self, client, async_test_superuser): """Test updating non-existent user.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) fake_id = uuid.uuid4() response = await client.patch( - f"/api/v1/users/{fake_id}", - headers=headers, - json={"first_name": "Ghost"} + f"/api/v1/users/{fake_id}", headers=headers, json={"first_name": "Ghost"} ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -435,15 +487,17 @@ class TestChangePassword: @pytest.mark.asyncio async def test_change_password_success(self, client, async_test_user, test_db): """Test successful password change.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.patch( "/api/v1/users/me/password", headers=headers, json={ "current_password": "TestPassword123!", - "new_password": "NewPassword123!" - } + "new_password": "NewPassword123!", + }, ) assert response.status_code == status.HTTP_200_OK @@ -453,25 +507,24 @@ class TestChangePassword: # Verify can login with new password login_response = await client.post( "/api/v1/auth/login", - json={ - "email": async_test_user.email, - "password": "NewPassword123!" - } + json={"email": async_test_user.email, "password": "NewPassword123!"}, ) assert login_response.status_code == status.HTTP_200_OK @pytest.mark.asyncio async def test_change_password_wrong_current(self, client, async_test_user): """Test that wrong current password is rejected.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.patch( "/api/v1/users/me/password", headers=headers, json={ "current_password": "WrongPassword123", - "new_password": "NewPassword123!" - } + "new_password": "NewPassword123!", + }, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -479,15 +532,14 @@ class TestChangePassword: @pytest.mark.asyncio async def test_change_password_weak_new_password(self, client, async_test_user): """Test that weak new passwords are rejected.""" - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) response = await client.patch( "/api/v1/users/me/password", headers=headers, - json={ - "current_password": "TestPassword123!", - "new_password": "weak" - } + json={"current_password": "TestPassword123!", "new_password": "weak"}, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @@ -499,8 +551,8 @@ class TestChangePassword: "/api/v1/users/me/password", json={ "current_password": "TestPassword123!", - "new_password": "NewPassword123!" - } + "new_password": "NewPassword123!", + }, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -511,9 +563,11 @@ class TestDeleteUser: """Tests for DELETE /users/{user_id} endpoint.""" @pytest.mark.asyncio - async def test_delete_user_as_superuser(self, client, async_test_superuser, async_test_db): + async def test_delete_user_as_superuser( + self, client, async_test_superuser, async_test_db + ): """Test deleting a user as superuser.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create a user to delete async with AsyncTestingSessionLocal() as session: @@ -522,14 +576,16 @@ class TestDeleteUser: password_hash="hash", first_name="Delete", is_active=True, - is_superuser=False + is_superuser=False, ) session.add(user_to_delete) await session.commit() await session.refresh(user_to_delete) user_id = user_to_delete.id - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) response = await client.delete(f"/api/v1/users/{user_id}", headers=headers) @@ -540,6 +596,7 @@ class TestDeleteUser: # Verify user is soft-deleted (has deleted_at timestamp) async with AsyncTestingSessionLocal() as session: from sqlalchemy import select + result = await session.execute(select(User).where(User.id == user_id)) deleted_user = result.scalar_one_or_none() assert deleted_user.deleted_at is not None @@ -547,9 +604,13 @@ class TestDeleteUser: @pytest.mark.asyncio async def test_cannot_delete_self(self, client, async_test_superuser): """Test that users cannot delete their own account.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) - response = await client.delete(f"/api/v1/users/{async_test_superuser.id}", headers=headers) + response = await client.delete( + f"/api/v1/users/{async_test_superuser.id}", headers=headers + ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -562,22 +623,28 @@ class TestDeleteUser: password_hash="hash", first_name="Protected", is_active=True, - is_superuser=False + is_superuser=False, ) test_db.add(other_user) test_db.commit() test_db.refresh(other_user) - headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") + headers = await get_auth_headers( + client, async_test_user.email, "TestPassword123!" + ) - response = await client.delete(f"/api/v1/users/{other_user.id}", headers=headers) + response = await client.delete( + f"/api/v1/users/{other_user.id}", headers=headers + ) assert response.status_code == status.HTTP_403_FORBIDDEN @pytest.mark.asyncio async def test_delete_nonexistent_user(self, client, async_test_superuser): """Test deleting non-existent user.""" - headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") + headers = await get_auth_headers( + client, async_test_superuser.email, "SuperPassword123!" + ) fake_id = uuid.uuid4() response = await client.delete(f"/api/v1/users/{fake_id}", headers=headers) diff --git a/backend/tests/api/test_users.py b/backend/tests/api/test_users.py index a33ef82..74a7f23 100644 --- a/backend/tests/api/test_users.py +++ b/backend/tests/api/test_users.py @@ -2,10 +2,12 @@ """ Tests for user routes. """ + +from uuid import uuid4 + import pytest import pytest_asyncio from fastapi import status -from uuid import uuid4 @pytest_asyncio.fixture @@ -13,10 +15,7 @@ async def superuser_token(client, async_test_superuser): """Get access token for superuser.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "superuser@example.com", - "password": "SuperPassword123!" - } + json={"email": "superuser@example.com", "password": "SuperPassword123!"}, ) assert response.status_code == 200 return response.json()["access_token"] @@ -27,10 +26,7 @@ async def user_token(client, async_test_user): """Get access token for regular user.""" response = await client.post( "/api/v1/auth/login", - json={ - "email": "testuser@example.com", - "password": "TestPassword123!" - } + json={"email": "testuser@example.com", "password": "TestPassword123!"}, ) assert response.status_code == 200 return response.json()["access_token"] @@ -43,8 +39,7 @@ class TestListUsers: async def test_list_users_success(self, client, superuser_token): """Test listing users successfully (covers lines 87-100).""" response = await client.get( - "/api/v1/users", - headers={"Authorization": f"Bearer {superuser_token}"} + "/api/v1/users", headers={"Authorization": f"Bearer {superuser_token}"} ) assert response.status_code == status.HTTP_200_OK @@ -58,7 +53,7 @@ class TestListUsers: """Test listing users with is_superuser filter (covers line 74).""" response = await client.get( "/api/v1/users?is_superuser=true", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -73,8 +68,7 @@ class TestGetCurrentUser: async def test_get_current_user_success(self, client, async_test_user, user_token): """Test getting current user profile.""" response = await client.get( - "/api/v1/users/me", - headers={"Authorization": f"Bearer {user_token}"} + "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"} ) assert response.status_code == status.HTTP_200_OK @@ -92,7 +86,7 @@ class TestUpdateCurrentUser: response = await client.patch( "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}, - json={"first_name": "UpdatedName"} + json={"first_name": "UpdatedName"}, ) assert response.status_code == status.HTTP_200_OK @@ -104,12 +98,14 @@ class TestUpdateCurrentUser: """Test database error handling during update (covers lines 162-169).""" from unittest.mock import patch - with patch('app.api.routes.users.user_crud.update', side_effect=Exception("DB error")): + with patch( + "app.api.routes.users.user_crud.update", side_effect=Exception("DB error") + ): with pytest.raises(Exception): await client.patch( "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}, - json={"first_name": "Updated"} + json={"first_name": "Updated"}, ) @pytest.mark.asyncio @@ -118,7 +114,7 @@ class TestUpdateCurrentUser: response = await client.patch( "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}, - json={"is_superuser": True} + json={"is_superuser": True}, ) # Pydantic validation should reject this at the schema level @@ -137,12 +133,15 @@ class TestUpdateCurrentUser: """Test ValueError handling during update (covers lines 165-166).""" from unittest.mock import patch - with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid value")): + with patch( + "app.api.routes.users.user_crud.update", + side_effect=ValueError("Invalid value"), + ): with pytest.raises(ValueError): await client.patch( "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}, - json={"first_name": "Updated"} + json={"first_name": "Updated"}, ) @@ -154,7 +153,7 @@ class TestGetUser: """Test getting user by ID.""" response = await client.get( f"/api/v1/users/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -167,7 +166,7 @@ class TestGetUser: fake_id = uuid4() response = await client.get( f"/api/v1/users/{fake_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -183,30 +182,34 @@ class TestUpdateUserById: response = await client.patch( f"/api/v1/users/{fake_id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"first_name": "Updated"} + json={"first_name": "Updated"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(self, client, async_test_user, user_token): + async def test_update_user_by_id_non_superuser_cannot_change_superuser_status( + self, client, async_test_user, user_token + ): """Test non-superuser cannot modify superuser status (Pydantic validation).""" response = await client.patch( f"/api/v1/users/{async_test_user.id}", headers={"Authorization": f"Bearer {user_token}"}, - json={"is_superuser": True} + json={"is_superuser": True}, ) # Pydantic validation should reject this at the schema level assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @pytest.mark.asyncio - async def test_update_user_by_id_success(self, client, async_test_user, superuser_token): + async def test_update_user_by_id_success( + self, client, async_test_user, superuser_token + ): """Test updating user successfully (covers lines 276-278).""" response = await client.patch( f"/api/v1/users/{async_test_user.id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"first_name": "SuperUpdated"} + json={"first_name": "SuperUpdated"}, ) assert response.status_code == status.HTTP_200_OK @@ -214,29 +217,37 @@ class TestUpdateUserById: assert data["first_name"] == "SuperUpdated" @pytest.mark.asyncio - async def test_update_user_by_id_value_error(self, client, async_test_user, superuser_token): + async def test_update_user_by_id_value_error( + self, client, async_test_user, superuser_token + ): """Test ValueError handling (covers lines 280-281).""" from unittest.mock import patch - with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid")): + with patch( + "app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid") + ): with pytest.raises(ValueError): await client.patch( f"/api/v1/users/{async_test_user.id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"first_name": "Updated"} + json={"first_name": "Updated"}, ) @pytest.mark.asyncio - async def test_update_user_by_id_unexpected_error(self, client, async_test_user, superuser_token): + async def test_update_user_by_id_unexpected_error( + self, client, async_test_user, superuser_token + ): """Test unexpected error handling (covers lines 283-284).""" from unittest.mock import patch - with patch('app.api.routes.users.user_crud.update', side_effect=Exception("Unexpected")): + with patch( + "app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected") + ): with pytest.raises(Exception): await client.patch( f"/api/v1/users/{async_test_user.id}", headers={"Authorization": f"Bearer {superuser_token}"}, - json={"first_name": "Updated"} + json={"first_name": "Updated"}, ) @@ -246,18 +257,18 @@ class TestChangePassword: @pytest.mark.asyncio async def test_change_password_success(self, client, async_test_db): """Test changing password successfully.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create a fresh user async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User new_user = User( email="changepass@example.com", password_hash=get_password_hash("OldPassword123!"), first_name="Change", - last_name="Pass" + last_name="Pass", ) session.add(new_user) await session.commit() @@ -265,10 +276,7 @@ class TestChangePassword: # Login login_response = await client.post( "/api/v1/auth/login", - json={ - "email": "changepass@example.com", - "password": "OldPassword123!" - } + json={"email": "changepass@example.com", "password": "OldPassword123!"}, ) token = login_response.json()["access_token"] @@ -278,8 +286,8 @@ class TestChangePassword: headers={"Authorization": f"Bearer {token}"}, json={ "current_password": "OldPassword123!", - "new_password": "NewPassword456!" - } + "new_password": "NewPassword456!", + }, ) assert response.status_code == status.HTTP_200_OK @@ -289,10 +297,7 @@ class TestChangePassword: # Verify new password works login_response = await client.post( "/api/v1/auth/login", - json={ - "email": "changepass@example.com", - "password": "NewPassword456!" - } + json={"email": "changepass@example.com", "password": "NewPassword456!"}, ) assert login_response.status_code == status.HTTP_200_OK @@ -306,7 +311,7 @@ class TestDeleteUserById: fake_id = uuid4() response = await client.delete( f"/api/v1/users/{fake_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -314,18 +319,18 @@ class TestDeleteUserById: @pytest.mark.asyncio async def test_delete_user_success(self, client, async_test_db, superuser_token): """Test deleting user successfully (covers lines 383-388).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create a user to delete async with AsyncTestingSessionLocal() as session: - from app.models.user import User from app.core.auth import get_password_hash + from app.models.user import User user_to_delete = User( email=f"delete{uuid4().hex[:8]}@example.com", password_hash=get_password_hash("Password123!"), first_name="Delete", - last_name="Me" + last_name="Me", ) session.add(user_to_delete) await session.commit() @@ -334,7 +339,7 @@ class TestDeleteUserById: response = await client.delete( f"/api/v1/users/{user_id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -342,25 +347,35 @@ class TestDeleteUserById: assert data["success"] is True @pytest.mark.asyncio - async def test_delete_user_value_error(self, client, async_test_user, superuser_token): + async def test_delete_user_value_error( + self, client, async_test_user, superuser_token + ): """Test ValueError handling during delete (covers lines 390-391).""" from unittest.mock import patch - with patch('app.api.routes.users.user_crud.soft_delete', side_effect=ValueError("Cannot delete")): + with patch( + "app.api.routes.users.user_crud.soft_delete", + side_effect=ValueError("Cannot delete"), + ): with pytest.raises(ValueError): await client.delete( f"/api/v1/users/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) @pytest.mark.asyncio - async def test_delete_user_unexpected_error(self, client, async_test_user, superuser_token): + async def test_delete_user_unexpected_error( + self, client, async_test_user, superuser_token + ): """Test unexpected error handling during delete (covers lines 393-394).""" from unittest.mock import patch - with patch('app.api.routes.users.user_crud.soft_delete', side_effect=Exception("Unexpected")): + with patch( + "app.api.routes.users.user_crud.soft_delete", + side_effect=Exception("Unexpected"), + ): with pytest.raises(Exception): await client.delete( f"/api/v1/users/{async_test_user.id}", - headers={"Authorization": f"Bearer {superuser_token}"} + headers={"Authorization": f"Bearer {superuser_token}"}, ) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index b9395bd..7ed5973 100755 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,28 +1,32 @@ # tests/conftest.py import os import uuid -from datetime import datetime, timezone import pytest import pytest_asyncio -from httpx import AsyncClient, ASGITransport +from httpx import ASGITransport, AsyncClient # Set IS_TEST environment variable BEFORE importing app # This prevents the scheduler from starting during tests os.environ["IS_TEST"] = "True" -from app.main import app -from app.core.database import get_db -from app.models.user import User from app.core.auth import get_password_hash -from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db +from app.core.database import get_db +from app.main import app +from app.models.user import User +from app.utils.test_utils import ( + setup_async_test_db, + setup_test_db, + teardown_async_test_db, + teardown_test_db, +) @pytest.fixture(scope="function") def db_session(): """ Creates a fresh SQLite in-memory database for each test function. - + Yields a SQLAlchemy session that can be used for testing. """ # Set up the database @@ -46,6 +50,7 @@ async def async_test_db(): yield test_engine, AsyncTestingSessionLocal await teardown_async_test_db(test_engine) + @pytest.fixture def user_create_data(): return { @@ -55,7 +60,7 @@ def user_create_data(): "last_name": "User", "phone_number": "+1234567890", "is_superuser": False, - "preferences": None + "preferences": None, } @@ -102,7 +107,7 @@ async def client(async_test_db): This overrides the get_db dependency to use the test database. """ - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async def override_get_db(): async with AsyncTestingSessionLocal() as session: @@ -176,7 +181,7 @@ async def async_test_user(async_test_db): Password: TestPassword123 """ - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user = User( id=uuid.uuid4(), @@ -202,7 +207,7 @@ async def async_test_superuser(async_test_db): Password: SuperPassword123 """ - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user = User( id=uuid.uuid4(), @@ -256,4 +261,4 @@ async def superuser_token(client, async_test_superuser): ) assert response.status_code == 200, f"Login failed: {response.text}" tokens = response.json() - return tokens["access_token"] \ No newline at end of file + return tokens["access_token"] diff --git a/backend/tests/core/test_auth.py b/backend/tests/core/test_auth.py index 500f1e3..42c12e2 100755 --- a/backend/tests/core/test_auth.py +++ b/backend/tests/core/test_auth.py @@ -1,20 +1,20 @@ # tests/core/test_auth.py import uuid +from datetime import UTC, datetime, timedelta + import pytest -from datetime import datetime, timedelta, timezone from jose import jwt -from pydantic import ValidationError from app.core.auth import ( - verify_password, - get_password_hash, + TokenExpiredError, + TokenInvalidError, + TokenMissingClaimError, create_access_token, create_refresh_token, decode_token, + get_password_hash, get_token_data, - TokenExpiredError, - TokenInvalidError, - TokenMissingClaimError + verify_password, ) from app.core.config import settings @@ -58,15 +58,13 @@ class TestTokenCreation: custom_claims = { "email": "test@example.com", "first_name": "Test", - "is_superuser": True + "is_superuser": True, } token = create_access_token(subject=user_id, claims=custom_claims) # Decode token to verify claims payload = jwt.decode( - token, - settings.SECRET_KEY, - algorithms=[settings.ALGORITHM] + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] ) # Check standard claims @@ -87,9 +85,7 @@ class TestTokenCreation: # Decode token to verify claims payload = jwt.decode( - token, - settings.SECRET_KEY, - algorithms=[settings.ALGORITHM] + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] ) # Check standard claims @@ -105,23 +101,18 @@ class TestTokenCreation: expires = timedelta(minutes=5) # Create token with specific expiration - token = create_access_token( - subject=user_id, - expires_delta=expires - ) + token = create_access_token(subject=user_id, expires_delta=expires) # Decode token payload = jwt.decode( - token, - settings.SECRET_KEY, - algorithms=[settings.ALGORITHM] + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] ) # Get actual expiration time from token - expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + expiration = datetime.fromtimestamp(payload["exp"], tz=UTC) # Calculate expected expiration (approximately) - now = datetime.now(timezone.utc) + now = datetime.now(UTC) expected_expiration = now + expires # Difference should be small (less than 1 second) @@ -148,7 +139,7 @@ class TestTokenDecoding: user_id = str(uuid.uuid4()) # Create a token that's already expired by directly manipulating the payload - now = datetime.now(timezone.utc) + now = datetime.now(UTC) expired_time = now - timedelta(hours=1) # 1 hour in the past # Create the expired token manually @@ -157,13 +148,11 @@ class TestTokenDecoding: "exp": int(expired_time.timestamp()), # Set expiration in the past "iat": int(now.timestamp()), "jti": str(uuid.uuid4()), - "type": "access" + "type": "access", } expired_token = jwt.encode( - payload, - settings.SECRET_KEY, - algorithm=settings.ALGORITHM + payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM ) # Attempting to decode should raise TokenExpiredError @@ -180,20 +169,16 @@ class TestTokenDecoding: def test_decode_token_with_missing_sub(self): """Test that a token without 'sub' claim raises TokenMissingClaimError""" # Create a token without a subject - now = datetime.now(timezone.utc) + now = datetime.now(UTC) payload = { "exp": int((now + timedelta(minutes=30)).timestamp()), "iat": int(now.timestamp()), "jti": str(uuid.uuid4()), - "type": "access" + "type": "access", # No 'sub' claim } - token = jwt.encode( - payload, - settings.SECRET_KEY, - algorithm=settings.ALGORITHM - ) + token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM) with pytest.raises(TokenMissingClaimError): decode_token(token) @@ -211,20 +196,16 @@ class TestTokenDecoding: """Test that a token with invalid payload structure raises TokenInvalidError""" # Create a token with an invalid payload structure - missing 'sub' which is required # but including 'exp' to avoid the expiration check - now = datetime.now(timezone.utc) + now = datetime.now(UTC) payload = { # Missing "sub" field which is required "exp": int((now + timedelta(minutes=30)).timestamp()), "iat": int(now.timestamp()), "jti": str(uuid.uuid4()), - "invalid_field": "test" + "invalid_field": "test", } - token = jwt.encode( - payload, - settings.SECRET_KEY, - algorithm=settings.ALGORITHM - ) + token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM) # Should raise TokenMissingClaimError due to missing 'sub' with pytest.raises(TokenMissingClaimError): @@ -236,11 +217,7 @@ class TestTokenDecoding: "exp": int((now + timedelta(minutes=30)).timestamp()), } - token = jwt.encode( - payload, - settings.SECRET_KEY, - algorithm=settings.ALGORITHM - ) + token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM) # Should raise TokenInvalidError due to ValidationError with pytest.raises(TokenInvalidError): @@ -249,12 +226,9 @@ class TestTokenDecoding: def test_get_token_data(self): """Test extracting TokenData from a token""" user_id = uuid.uuid4() - token = create_access_token( - subject=str(user_id), - claims={"is_superuser": True} - ) + token = create_access_token(subject=str(user_id), claims={"is_superuser": True}) token_data = get_token_data(token) assert token_data.user_id == user_id - assert token_data.is_superuser is True \ No newline at end of file + assert token_data.is_superuser is True diff --git a/backend/tests/core/test_auth_security.py b/backend/tests/core/test_auth_security.py index 02d00f3..2f0d1ba 100644 --- a/backend/tests/core/test_auth_security.py +++ b/backend/tests/core/test_auth_security.py @@ -8,11 +8,11 @@ Critical security tests covering: These tests cover critical security vulnerabilities that could be exploited. """ + import pytest from jose import jwt -from datetime import datetime, timedelta, timezone -from app.core.auth import decode_token, create_access_token, TokenInvalidError +from app.core.auth import TokenInvalidError, create_access_token, decode_token from app.core.config import settings @@ -46,13 +46,14 @@ class TestJWTAlgorithmSecurityAttacks: """ # Create a payload that would normally be valid (using timestamps) import time + now = int(time.time()) payload = { "sub": "user123", "exp": now + 3600, # 1 hour from now "iat": now, - "type": "access" + "type": "access", } # Craft a malicious token with "alg: none" @@ -61,13 +62,13 @@ class TestJWTAlgorithmSecurityAttacks: import json header = {"alg": "none", "typ": "JWT"} - header_encoded = base64.urlsafe_b64encode( - json.dumps(header).encode() - ).decode().rstrip("=") + header_encoded = ( + base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=") + ) - payload_encoded = base64.urlsafe_b64encode( - json.dumps(payload).encode() - ).decode().rstrip("=") + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) # Token with no signature (algorithm "none") malicious_token = f"{header_encoded}.{payload_encoded}." @@ -85,22 +86,17 @@ class TestJWTAlgorithmSecurityAttacks: import time now = int(time.time()) - payload = { - "sub": "user123", - "exp": now + 3600, - "iat": now, - "type": "access" - } + payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"} # Try uppercase "NONE" header = {"alg": "NONE", "typ": "JWT"} - header_encoded = base64.urlsafe_b64encode( - json.dumps(header).encode() - ).decode().rstrip("=") + header_encoded = ( + base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=") + ) - payload_encoded = base64.urlsafe_b64encode( - json.dumps(payload).encode() - ).decode().rstrip("=") + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) malicious_token = f"{header_encoded}.{payload_encoded}." @@ -121,15 +117,11 @@ class TestJWTAlgorithmSecurityAttacks: before our defensive checks at line 212. This is good for security! """ import time + now = int(time.time()) # Create a valid payload - payload = { - "sub": "user123", - "exp": now + 3600, - "iat": now, - "type": "access" - } + payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"} # Encode with wrong algorithm (RS256 instead of HS256) # This simulates an attacker trying algorithm substitution @@ -137,9 +129,7 @@ class TestJWTAlgorithmSecurityAttacks: try: malicious_token = jwt.encode( - payload, - settings.SECRET_KEY, - algorithm=wrong_algorithm + payload, settings.SECRET_KEY, algorithm=wrong_algorithm ) # Should reject the token (library catches mismatch) @@ -156,21 +146,15 @@ class TestJWTAlgorithmSecurityAttacks: Prevents algorithm downgrade/upgrade attacks. """ import time + now = int(time.time()) - payload = { - "sub": "user123", - "exp": now + 3600, - "iat": now, - "type": "access" - } + payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"} # Create token with HS384 instead of HS256 try: malicious_token = jwt.encode( - payload, - settings.SECRET_KEY, - algorithm="HS384" + payload, settings.SECRET_KEY, algorithm="HS384" ) with pytest.raises(TokenInvalidError): @@ -223,20 +207,15 @@ class TestJWTSecurityEdgeCases: # Create token without "alg" in header header = {"typ": "JWT"} # Missing "alg" - payload = { - "sub": "user123", - "exp": now + 3600, - "iat": now, - "type": "access" - } + payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"} - header_encoded = base64.urlsafe_b64encode( - json.dumps(header).encode() - ).decode().rstrip("=") + header_encoded = ( + base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=") + ) - payload_encoded = base64.urlsafe_b64encode( - json.dumps(payload).encode() - ).decode().rstrip("=") + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) malicious_token = f"{header_encoded}.{payload_encoded}.fake_signature" @@ -253,15 +232,20 @@ class TestJWTSecurityEdgeCases: """Test token with malformed JSON in payload.""" import base64 - header = {"alg": "HS256", "typ": "JWT"} - header_encoded = base64.urlsafe_b64encode( - b'{"alg":"HS256","typ":"JWT"}' - ).decode().rstrip("=") + header_encoded = ( + base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}') + .decode() + .rstrip("=") + ) # Invalid JSON (missing closing brace) - invalid_payload_encoded = base64.urlsafe_b64encode( - b'{"sub":"user123"' # Invalid JSON - ).decode().rstrip("=") + invalid_payload_encoded = ( + base64.urlsafe_b64encode( + b'{"sub":"user123"' # Invalid JSON + ) + .decode() + .rstrip("=") + ) malicious_token = f"{header_encoded}.{invalid_payload_encoded}.fake_sig" diff --git a/backend/tests/core/test_config.py b/backend/tests/core/test_config.py index c5cadcd..a2ba0b1 100755 --- a/backend/tests/core/test_config.py +++ b/backend/tests/core/test_config.py @@ -1,6 +1,7 @@ # tests/core/test_config.py import pytest from pydantic import ValidationError + from app.core.config import Settings @@ -22,11 +23,15 @@ class TestSecretKeyValidation: with pytest.raises(ValidationError) as exc_info: Settings(SECRET_KEY=default_key, ENVIRONMENT="production") - assert "must be set to a secure random value in production" in str(exc_info.value) + assert "must be set to a secure random value in production" in str( + exc_info.value + ) def test_default_secret_key_in_development_allows_with_warning(self, caplog): """Test that default SECRET_KEY in development is allowed but warns""" - settings = Settings(SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development") + settings = Settings( + SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development" + ) assert settings.SECRET_KEY == "your_secret_key_here" + "x" * 14 # Note: The warning happens during validation, which we've seen works @@ -44,19 +49,13 @@ class TestSuperuserPasswordValidation: def test_none_password_accepted(self): """Test that None password is accepted (optional field)""" - settings = Settings( - SECRET_KEY="a" * 32, - FIRST_SUPERUSER_PASSWORD=None - ) + settings = Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=None) assert settings.FIRST_SUPERUSER_PASSWORD is None def test_password_too_short_raises_error(self): """Test that password shorter than 12 characters raises error""" with pytest.raises(ValidationError) as exc_info: - Settings( - SECRET_KEY="a" * 32, - FIRST_SUPERUSER_PASSWORD="Short1" - ) + Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="Short1") assert "must be at least 12 characters" in str(exc_info.value) @@ -64,14 +63,11 @@ class TestSuperuserPasswordValidation: """Test that common weak passwords are rejected""" # Test with the exact weak passwords from the validator # These are in the weak_passwords set and should be rejected - weak_passwords = ['123456789012'] # Exactly 12 chars, in the weak set + weak_passwords = ["123456789012"] # Exactly 12 chars, in the weak set for weak_pwd in weak_passwords: with pytest.raises(ValidationError) as exc_info: - Settings( - SECRET_KEY="a" * 32, - FIRST_SUPERUSER_PASSWORD=weak_pwd - ) + Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=weak_pwd) # Should get "too weak" message error_str = str(exc_info.value) assert "too weak" in error_str @@ -79,30 +75,21 @@ class TestSuperuserPasswordValidation: def test_password_without_lowercase_rejected(self): """Test that password without lowercase is rejected""" with pytest.raises(ValidationError) as exc_info: - Settings( - SECRET_KEY="a" * 32, - FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123" - ) + Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123") assert "must contain lowercase, uppercase, and digits" in str(exc_info.value) def test_password_without_uppercase_rejected(self): """Test that password without uppercase is rejected""" with pytest.raises(ValidationError) as exc_info: - Settings( - SECRET_KEY="a" * 32, - FIRST_SUPERUSER_PASSWORD="alllowercase123" - ) + Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="alllowercase123") assert "must contain lowercase, uppercase, and digits" in str(exc_info.value) def test_password_without_digit_rejected(self): """Test that password without digit is rejected""" with pytest.raises(ValidationError) as exc_info: - Settings( - SECRET_KEY="a" * 32, - FIRST_SUPERUSER_PASSWORD="NoDigitsHere" - ) + Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="NoDigitsHere") assert "must contain lowercase, uppercase, and digits" in str(exc_info.value) @@ -110,8 +97,7 @@ class TestSuperuserPasswordValidation: """Test that strong password is accepted""" strong_password = "StrongPassword123!" settings = Settings( - SECRET_KEY="a" * 32, - FIRST_SUPERUSER_PASSWORD=strong_password + SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=strong_password ) assert settings.FIRST_SUPERUSER_PASSWORD == strong_password @@ -150,7 +136,7 @@ class TestDatabaseConfiguration: POSTGRES_HOST="testhost", POSTGRES_PORT="5432", POSTGRES_DB="testdb", - DATABASE_URL=None # Don't use explicit URL + DATABASE_URL=None, # Don't use explicit URL ) expected_url = "postgresql://testuser:testpass@testhost:5432/testdb" @@ -159,10 +145,7 @@ class TestDatabaseConfiguration: def test_explicit_database_url_used_when_set(self): """Test that explicit DATABASE_URL is used when provided""" explicit_url = "postgresql://explicit:pass@host:5432/db" - settings = Settings( - SECRET_KEY="a" * 32, - DATABASE_URL=explicit_url - ) + settings = Settings(SECRET_KEY="a" * 32, DATABASE_URL=explicit_url) assert settings.database_url == explicit_url diff --git a/backend/tests/core/test_config_security.py b/backend/tests/core/test_config_security.py index b9c8217..541264e 100644 --- a/backend/tests/core/test_config_security.py +++ b/backend/tests/core/test_config_security.py @@ -6,8 +6,10 @@ Critical security tests covering: These tests prevent security misconfigurations. """ -import pytest + import os + +import pytest from pydantic import ValidationError @@ -43,6 +45,7 @@ class TestSecretKeySecurityValidation: # Import Settings class fresh (to pick up new env var) # The ValidationError should be raised during reload when Settings() is instantiated import importlib + from app.core import config # Reload will raise ValidationError because Settings() is instantiated at module level @@ -58,7 +61,9 @@ class TestSecretKeySecurityValidation: # Reload config to restore original settings import importlib + from app.core import config + importlib.reload(config) def test_secret_key_exactly_32_characters_accepted(self): @@ -75,7 +80,9 @@ class TestSecretKeySecurityValidation: os.environ["SECRET_KEY"] = key_32 import importlib + from app.core import config + importlib.reload(config) # Should work @@ -89,7 +96,9 @@ class TestSecretKeySecurityValidation: os.environ.pop("SECRET_KEY", None) import importlib + from app.core import config + importlib.reload(config) def test_secret_key_long_enough_accepted(self): @@ -106,7 +115,9 @@ class TestSecretKeySecurityValidation: os.environ["SECRET_KEY"] = key_64 import importlib + from app.core import config + importlib.reload(config) # Should work @@ -120,7 +131,9 @@ class TestSecretKeySecurityValidation: os.environ.pop("SECRET_KEY", None) import importlib + from app.core import config + importlib.reload(config) def test_default_secret_key_meets_requirements(self): @@ -132,4 +145,6 @@ class TestSecretKeySecurityValidation: from app.core.config import settings # Current settings should have valid SECRET_KEY - assert len(settings.SECRET_KEY) >= 32, "Default SECRET_KEY must be at least 32 chars" + assert len(settings.SECRET_KEY) >= 32, ( + "Default SECRET_KEY must be at least 32 chars" + ) diff --git a/backend/tests/core/test_database.py b/backend/tests/core/test_database.py index 0c1ce50..fa17411 100644 --- a/backend/tests/core/test_database.py +++ b/backend/tests/core/test_database.py @@ -9,18 +9,19 @@ Covers: - init_async_db - close_async_db """ + +from unittest.mock import patch + import pytest -import pytest_asyncio -from unittest.mock import patch, MagicMock, AsyncMock from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import ( - get_async_database_url, - get_db, async_transaction_scope, check_async_database_health, - init_async_db, close_async_db, + get_async_database_url, + get_db, + init_async_db, ) @@ -88,12 +89,13 @@ class TestAsyncTransactionScope: async def test_transaction_scope_commits_on_success(self, async_test_db): """Test that successful operations are committed (covers line 138).""" # Mock the transaction scope to use test database - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db - with patch('app.core.database.SessionLocal', SessionLocal): + with patch("app.core.database.SessionLocal", SessionLocal): async with async_transaction_scope() as db: # Execute a simple query to verify transaction works from sqlalchemy import text + result = await db.execute(text("SELECT 1")) assert result is not None # Transaction should be committed (covers line 138 debug log) @@ -101,12 +103,13 @@ class TestAsyncTransactionScope: @pytest.mark.asyncio async def test_transaction_scope_rollback_on_error(self, async_test_db): """Test that transaction rolls back on exception.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db - with patch('app.core.database.SessionLocal', SessionLocal): + with patch("app.core.database.SessionLocal", SessionLocal): with pytest.raises(RuntimeError, match="Test error"): async with async_transaction_scope() as db: from sqlalchemy import text + await db.execute(text("SELECT 1")) raise RuntimeError("Test error") @@ -117,9 +120,9 @@ class TestCheckAsyncDatabaseHealth: @pytest.mark.asyncio async def test_database_health_check_success(self, async_test_db): """Test health check returns True on success (covers line 156).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db - with patch('app.core.database.SessionLocal', SessionLocal): + with patch("app.core.database.SessionLocal", SessionLocal): result = await check_async_database_health() assert result is True @@ -127,7 +130,7 @@ class TestCheckAsyncDatabaseHealth: async def test_database_health_check_failure(self): """Test health check returns False on database error.""" # Mock async_transaction_scope to raise an error - with patch('app.core.database.async_transaction_scope') as mock_scope: + with patch("app.core.database.async_transaction_scope") as mock_scope: mock_scope.side_effect = Exception("Database connection failed") result = await check_async_database_health() @@ -140,10 +143,10 @@ class TestInitAsyncDb: @pytest.mark.asyncio async def test_init_async_db_creates_tables(self, async_test_db): """Test init_async_db creates tables (covers lines 174-176).""" - test_engine, SessionLocal = async_test_db + test_engine, _SessionLocal = async_test_db # Mock the engine to use test engine - with patch('app.core.database.engine', test_engine): + with patch("app.core.database.engine", test_engine): await init_async_db() # If no exception, tables were created successfully @@ -155,7 +158,6 @@ class TestCloseAsyncDb: async def test_close_async_db_disposes_engine(self): """Test close_async_db disposes engine (covers lines 185-186).""" # Create a fresh engine to test closing - from app.core.database import engine # Close connections await close_async_db() diff --git a/backend/tests/crud/test_base.py b/backend/tests/crud/test_base.py index aab66c2..e6a6b9c 100644 --- a/backend/tests/crud/test_base.py +++ b/backend/tests/crud/test_base.py @@ -2,14 +2,16 @@ """ Comprehensive tests for CRUDBase class covering all error paths and edge cases. """ + +from datetime import UTC +from unittest.mock import patch +from uuid import uuid4 + import pytest -from uuid import uuid4, UUID -from sqlalchemy.exc import IntegrityError, OperationalError, DataError +from sqlalchemy.exc import DataError, IntegrityError, OperationalError from sqlalchemy.orm import joinedload -from unittest.mock import AsyncMock, patch, MagicMock from app.crud.user import user as user_crud -from app.models.user import User from app.schemas.users import UserCreate, UserUpdate @@ -19,7 +21,7 @@ class TestCRUDBaseGet: @pytest.mark.asyncio async def test_get_with_invalid_uuid_string(self, async_test_db): """Test get with invalid UUID string returns None.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.get(session, id="invalid-uuid") @@ -28,7 +30,7 @@ class TestCRUDBaseGet: @pytest.mark.asyncio async def test_get_with_invalid_uuid_type(self, async_test_db): """Test get with invalid UUID type returns None.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.get(session, id=12345) # int instead of UUID @@ -37,7 +39,7 @@ class TestCRUDBaseGet: @pytest.mark.asyncio async def test_get_with_uuid_object(self, async_test_db, async_test_user): """Test get with UUID object instead of string.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: # Pass UUID object directly @@ -48,26 +50,24 @@ class TestCRUDBaseGet: @pytest.mark.asyncio async def test_get_with_options(self, async_test_db, async_test_user): """Test get with eager loading options (tests lines 76-78).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: # Test that options parameter is accepted and doesn't error # We pass an empty list which still tests the code path result = await user_crud.get( - session, - id=str(async_test_user.id), - options=[] + session, id=str(async_test_user.id), options=[] ) assert result is not None @pytest.mark.asyncio async def test_get_database_error(self, async_test_db): """Test get handles database errors properly.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: # Mock execute to raise an exception - with patch.object(session, 'execute', side_effect=Exception("DB error")): + with patch.object(session, "execute", side_effect=Exception("DB error")): with pytest.raises(Exception, match="DB error"): await user_crud.get(session, id=str(uuid4())) @@ -78,7 +78,7 @@ class TestCRUDBaseGetMulti: @pytest.mark.asyncio async def test_get_multi_negative_skip(self, async_test_db): """Test get_multi with negative skip raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: with pytest.raises(ValueError, match="skip must be non-negative"): @@ -87,7 +87,7 @@ class TestCRUDBaseGetMulti: @pytest.mark.asyncio async def test_get_multi_negative_limit(self, async_test_db): """Test get_multi with negative limit raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: with pytest.raises(ValueError, match="limit must be non-negative"): @@ -96,7 +96,7 @@ class TestCRUDBaseGetMulti: @pytest.mark.asyncio async def test_get_multi_limit_too_large(self, async_test_db): """Test get_multi with limit > 1000 raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: with pytest.raises(ValueError, match="Maximum limit is 1000"): @@ -105,25 +105,20 @@ class TestCRUDBaseGetMulti: @pytest.mark.asyncio async def test_get_multi_with_options(self, async_test_db, async_test_user): """Test get_multi with eager loading options (tests lines 118-120).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: # Test that options parameter is accepted - results = await user_crud.get_multi( - session, - skip=0, - limit=10, - options=[] - ) + results = await user_crud.get_multi(session, skip=0, limit=10, options=[]) assert isinstance(results, list) @pytest.mark.asyncio async def test_get_multi_database_error(self, async_test_db): """Test get_multi handles database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("DB error")): + with patch.object(session, "execute", side_effect=Exception("DB error")): with pytest.raises(Exception, match="DB error"): await user_crud.get_multi(session) @@ -134,7 +129,7 @@ class TestCRUDBaseCreate: @pytest.mark.asyncio async def test_create_duplicate_unique_field(self, async_test_db, async_test_user): """Test create with duplicate unique field raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: # Try to create user with duplicate email @@ -142,7 +137,7 @@ class TestCRUDBaseCreate: email=async_test_user.email, # Duplicate! password="TestPassword123!", first_name="Test", - last_name="Duplicate" + last_name="Duplicate", ) with pytest.raises(ValueError, match="already exists"): @@ -151,22 +146,23 @@ class TestCRUDBaseCreate: @pytest.mark.asyncio async def test_create_integrity_error_non_duplicate(self, async_test_db): """Test create with non-duplicate IntegrityError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: # Mock commit to raise IntegrityError without "unique" in message - original_commit = session.commit async def mock_commit(): - error = IntegrityError("statement", {}, Exception("foreign key violation")) + error = IntegrityError( + "statement", {}, Exception("foreign key violation") + ) raise error - with patch.object(session, 'commit', side_effect=mock_commit): + with patch.object(session, "commit", side_effect=mock_commit): user_data = UserCreate( email="test@example.com", password="TestPassword123!", first_name="Test", - last_name="User" + last_name="User", ) with pytest.raises(ValueError, match="Database integrity error"): @@ -175,15 +171,21 @@ class TestCRUDBaseCreate: @pytest.mark.asyncio async def test_create_operational_error(self, async_test_db): """Test create with OperationalError (user CRUD catches as generic Exception).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection lost"))): + with patch.object( + session, + "commit", + side_effect=OperationalError( + "statement", {}, Exception("connection lost") + ), + ): user_data = UserCreate( email="test@example.com", password="TestPassword123!", first_name="Test", - last_name="User" + last_name="User", ) # User CRUD catches this as generic Exception and re-raises @@ -193,15 +195,19 @@ class TestCRUDBaseCreate: @pytest.mark.asyncio async def test_create_data_error(self, async_test_db): """Test create with DataError (user CRUD catches as generic Exception).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with patch.object(session, 'commit', side_effect=DataError("statement", {}, Exception("invalid data"))): + with patch.object( + session, + "commit", + side_effect=DataError("statement", {}, Exception("invalid data")), + ): user_data = UserCreate( email="test@example.com", password="TestPassword123!", first_name="Test", - last_name="User" + last_name="User", ) # User CRUD catches this as generic Exception and re-raises @@ -211,15 +217,17 @@ class TestCRUDBaseCreate: @pytest.mark.asyncio async def test_create_unexpected_error(self, async_test_db): """Test create with unexpected exception.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected error")): + with patch.object( + session, "commit", side_effect=RuntimeError("Unexpected error") + ): user_data = UserCreate( email="test@example.com", password="TestPassword123!", first_name="Test", - last_name="User" + last_name="User", ) with pytest.raises(RuntimeError, match="Unexpected error"): @@ -232,16 +240,17 @@ class TestCRUDBaseUpdate: @pytest.mark.asyncio async def test_update_duplicate_unique_field(self, async_test_db, async_test_user): """Test update with duplicate unique field raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create another user async with SessionLocal() as session: from app.crud.user import user as user_crud + user2_data = UserCreate( email="user2@example.com", password="TestPassword123!", first_name="User", - last_name="Two" + last_name="Two", ) user2 = await user_crud.create(session, obj_in=user2_data) await session.commit() @@ -250,63 +259,89 @@ class TestCRUDBaseUpdate: async with SessionLocal() as session: user2_obj = await user_crud.get(session, id=str(user2.id)) - with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("UNIQUE constraint failed"))): + with patch.object( + session, + "commit", + side_effect=IntegrityError( + "statement", {}, Exception("UNIQUE constraint failed") + ), + ): update_data = UserUpdate(email=async_test_user.email) with pytest.raises(ValueError, match="already exists"): - await user_crud.update(session, db_obj=user2_obj, obj_in=update_data) + await user_crud.update( + session, db_obj=user2_obj, obj_in=update_data + ) @pytest.mark.asyncio async def test_update_with_dict(self, async_test_db, async_test_user): """Test update with dict instead of schema.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) # Update with dict (tests lines 164-165) updated = await user_crud.update( - session, - db_obj=user, - obj_in={"first_name": "UpdatedName"} + session, db_obj=user, obj_in={"first_name": "UpdatedName"} ) assert updated.first_name == "UpdatedName" @pytest.mark.asyncio async def test_update_integrity_error(self, async_test_db, async_test_user): """Test update with IntegrityError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) - with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("constraint failed"))): + with patch.object( + session, + "commit", + side_effect=IntegrityError( + "statement", {}, Exception("constraint failed") + ), + ): with pytest.raises(ValueError, match="Database integrity error"): - await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"}) + await user_crud.update( + session, db_obj=user, obj_in={"first_name": "Test"} + ) @pytest.mark.asyncio async def test_update_operational_error(self, async_test_db, async_test_user): """Test update with OperationalError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) - with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection error"))): + with patch.object( + session, + "commit", + side_effect=OperationalError( + "statement", {}, Exception("connection error") + ), + ): with pytest.raises(ValueError, match="Database operation failed"): - await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"}) + await user_crud.update( + session, db_obj=user, obj_in={"first_name": "Test"} + ) @pytest.mark.asyncio async def test_update_unexpected_error(self, async_test_db, async_test_user): """Test update with unexpected error.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) - with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")): + with patch.object( + session, "commit", side_effect=RuntimeError("Unexpected") + ): with pytest.raises(RuntimeError): - await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"}) + await user_crud.update( + session, db_obj=user, obj_in={"first_name": "Test"} + ) class TestCRUDBaseRemove: @@ -315,7 +350,7 @@ class TestCRUDBaseRemove: @pytest.mark.asyncio async def test_remove_invalid_uuid(self, async_test_db): """Test remove with invalid UUID returns None.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.remove(session, id="invalid-uuid") @@ -324,7 +359,7 @@ class TestCRUDBaseRemove: @pytest.mark.asyncio async def test_remove_with_uuid_object(self, async_test_db, async_test_user): """Test remove with UUID object.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create a user to delete async with SessionLocal() as session: @@ -332,7 +367,7 @@ class TestCRUDBaseRemove: email="todelete@example.com", password="TestPassword123!", first_name="To", - last_name="Delete" + last_name="Delete", ) user = await user_crud.create(session, obj_in=user_data) user_id = user.id @@ -347,7 +382,7 @@ class TestCRUDBaseRemove: @pytest.mark.asyncio async def test_remove_nonexistent(self, async_test_db): """Test remove of nonexistent record returns None.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.remove(session, id=str(uuid4())) @@ -356,21 +391,31 @@ class TestCRUDBaseRemove: @pytest.mark.asyncio async def test_remove_integrity_error(self, async_test_db, async_test_user): """Test remove with IntegrityError (foreign key constraint).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: # Mock delete to raise IntegrityError - with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("FOREIGN KEY constraint"))): - with pytest.raises(ValueError, match="Cannot delete.*referenced by other records"): + with patch.object( + session, + "commit", + side_effect=IntegrityError( + "statement", {}, Exception("FOREIGN KEY constraint") + ), + ): + with pytest.raises( + ValueError, match="Cannot delete.*referenced by other records" + ): await user_crud.remove(session, id=str(async_test_user.id)) @pytest.mark.asyncio async def test_remove_unexpected_error(self, async_test_db, async_test_user): """Test remove with unexpected error.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")): + with patch.object( + session, "commit", side_effect=RuntimeError("Unexpected") + ): with pytest.raises(RuntimeError): await user_crud.remove(session, id=str(async_test_user.id)) @@ -381,10 +426,12 @@ class TestCRUDBaseGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_basic(self, async_test_db, async_test_user): """Test get_multi_with_total basic functionality.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - items, total = await user_crud.get_multi_with_total(session, skip=0, limit=10) + items, total = await user_crud.get_multi_with_total( + session, skip=0, limit=10 + ) assert isinstance(items, list) assert isinstance(total, int) assert total >= 1 # At least the test user @@ -392,7 +439,7 @@ class TestCRUDBaseGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_negative_skip(self, async_test_db): """Test get_multi_with_total with negative skip raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: with pytest.raises(ValueError, match="skip must be non-negative"): @@ -401,7 +448,7 @@ class TestCRUDBaseGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_negative_limit(self, async_test_db): """Test get_multi_with_total with negative limit raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: with pytest.raises(ValueError, match="limit must be non-negative"): @@ -410,28 +457,34 @@ class TestCRUDBaseGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_limit_too_large(self, async_test_db): """Test get_multi_with_total with limit > 1000 raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: with pytest.raises(ValueError, match="Maximum limit is 1000"): await user_crud.get_multi_with_total(session, limit=1001) @pytest.mark.asyncio - async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user): + async def test_get_multi_with_total_with_filters( + self, async_test_db, async_test_user + ): """Test get_multi_with_total with filters.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: filters = {"email": async_test_user.email} - items, total = await user_crud.get_multi_with_total(session, filters=filters) + items, total = await user_crud.get_multi_with_total( + session, filters=filters + ) assert total == 1 assert len(items) == 1 assert items[0].email == async_test_user.email @pytest.mark.asyncio - async def test_get_multi_with_total_with_sorting_asc(self, async_test_db, async_test_user): + async def test_get_multi_with_total_with_sorting_asc( + self, async_test_db, async_test_user + ): """Test get_multi_with_total with ascending sort.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create additional users async with SessionLocal() as session: @@ -439,13 +492,13 @@ class TestCRUDBaseGetMultiWithTotal: email="aaa@example.com", password="TestPassword123!", first_name="AAA", - last_name="User" + last_name="User", ) user_data2 = UserCreate( email="zzz@example.com", password="TestPassword123!", first_name="ZZZ", - last_name="User" + last_name="User", ) await user_crud.create(session, obj_in=user_data1) await user_crud.create(session, obj_in=user_data2) @@ -460,9 +513,11 @@ class TestCRUDBaseGetMultiWithTotal: assert items[0].email == "aaa@example.com" @pytest.mark.asyncio - async def test_get_multi_with_total_with_sorting_desc(self, async_test_db, async_test_user): + async def test_get_multi_with_total_with_sorting_desc( + self, async_test_db, async_test_user + ): """Test get_multi_with_total with descending sort.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create additional users async with SessionLocal() as session: @@ -470,20 +525,20 @@ class TestCRUDBaseGetMultiWithTotal: email="bbb@example.com", password="TestPassword123!", first_name="BBB", - last_name="User" + last_name="User", ) user_data2 = UserCreate( email="ccc@example.com", password="TestPassword123!", first_name="CCC", - last_name="User" + last_name="User", ) await user_crud.create(session, obj_in=user_data1) await user_crud.create(session, obj_in=user_data2) await session.commit() async with SessionLocal() as session: - items, total = await user_crud.get_multi_with_total( + items, _total = await user_crud.get_multi_with_total( session, sort_by="email", sort_order="desc", limit=1 ) assert len(items) == 1 @@ -492,7 +547,7 @@ class TestCRUDBaseGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_with_pagination(self, async_test_db): """Test get_multi_with_total pagination works correctly.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create minimal users for pagination test (3 instead of 5) async with SessionLocal() as session: @@ -501,19 +556,23 @@ class TestCRUDBaseGetMultiWithTotal: email=f"user{i}@example.com", password="TestPassword123!", first_name=f"User{i}", - last_name="Test" + last_name="Test", ) await user_crud.create(session, obj_in=user_data) await session.commit() async with SessionLocal() as session: # Get first page - items1, total = await user_crud.get_multi_with_total(session, skip=0, limit=2) + items1, total = await user_crud.get_multi_with_total( + session, skip=0, limit=2 + ) assert len(items1) == 2 assert total >= 3 # Get second page - items2, total2 = await user_crud.get_multi_with_total(session, skip=2, limit=2) + items2, total2 = await user_crud.get_multi_with_total( + session, skip=2, limit=2 + ) assert len(items2) >= 1 assert total2 == total @@ -529,7 +588,7 @@ class TestCRUDBaseCount: @pytest.mark.asyncio async def test_count_basic(self, async_test_db, async_test_user): """Test count returns correct number.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: count = await user_crud.count(session) @@ -539,7 +598,7 @@ class TestCRUDBaseCount: @pytest.mark.asyncio async def test_count_multiple_users(self, async_test_db, async_test_user): """Test count with multiple users.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create additional users async with SessionLocal() as session: @@ -549,13 +608,13 @@ class TestCRUDBaseCount: email="count1@example.com", password="TestPassword123!", first_name="Count", - last_name="One" + last_name="One", ) user_data2 = UserCreate( email="count2@example.com", password="TestPassword123!", first_name="Count", - last_name="Two" + last_name="Two", ) await user_crud.create(session, obj_in=user_data1) await user_crud.create(session, obj_in=user_data2) @@ -568,10 +627,10 @@ class TestCRUDBaseCount: @pytest.mark.asyncio async def test_count_database_error(self, async_test_db): """Test count handles database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("DB error")): + with patch.object(session, "execute", side_effect=Exception("DB error")): with pytest.raises(Exception, match="DB error"): await user_crud.count(session) @@ -582,7 +641,7 @@ class TestCRUDBaseExists: @pytest.mark.asyncio async def test_exists_true(self, async_test_db, async_test_user): """Test exists returns True for existing record.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.exists(session, id=str(async_test_user.id)) @@ -591,7 +650,7 @@ class TestCRUDBaseExists: @pytest.mark.asyncio async def test_exists_false(self, async_test_db): """Test exists returns False for non-existent record.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.exists(session, id=str(uuid4())) @@ -600,7 +659,7 @@ class TestCRUDBaseExists: @pytest.mark.asyncio async def test_exists_invalid_uuid(self, async_test_db): """Test exists returns False for invalid UUID.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.exists(session, id="invalid-uuid") @@ -613,7 +672,7 @@ class TestCRUDBaseSoftDelete: @pytest.mark.asyncio async def test_soft_delete_success(self, async_test_db): """Test soft delete sets deleted_at timestamp.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create a user to soft delete async with SessionLocal() as session: @@ -621,7 +680,7 @@ class TestCRUDBaseSoftDelete: email="softdelete@example.com", password="TestPassword123!", first_name="Soft", - last_name="Delete" + last_name="Delete", ) user = await user_crud.create(session, obj_in=user_data) user_id = user.id @@ -636,7 +695,7 @@ class TestCRUDBaseSoftDelete: @pytest.mark.asyncio async def test_soft_delete_invalid_uuid(self, async_test_db): """Test soft delete with invalid UUID returns None.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.soft_delete(session, id="invalid-uuid") @@ -645,7 +704,7 @@ class TestCRUDBaseSoftDelete: @pytest.mark.asyncio async def test_soft_delete_nonexistent(self, async_test_db): """Test soft delete of nonexistent record returns None.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.soft_delete(session, id=str(uuid4())) @@ -654,7 +713,7 @@ class TestCRUDBaseSoftDelete: @pytest.mark.asyncio async def test_soft_delete_with_uuid_object(self, async_test_db): """Test soft delete with UUID object.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create a user to soft delete async with SessionLocal() as session: @@ -662,7 +721,7 @@ class TestCRUDBaseSoftDelete: email="softdelete2@example.com", password="TestPassword123!", first_name="Soft", - last_name="Delete2" + last_name="Delete2", ) user = await user_crud.create(session, obj_in=user_data) user_id = user.id @@ -681,7 +740,7 @@ class TestCRUDBaseRestore: @pytest.mark.asyncio async def test_restore_success(self, async_test_db): """Test restore clears deleted_at timestamp.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create and soft delete a user async with SessionLocal() as session: @@ -689,7 +748,7 @@ class TestCRUDBaseRestore: email="restore@example.com", password="TestPassword123!", first_name="Restore", - last_name="Test" + last_name="Test", ) user = await user_crud.create(session, obj_in=user_data) user_id = user.id @@ -707,7 +766,7 @@ class TestCRUDBaseRestore: @pytest.mark.asyncio async def test_restore_invalid_uuid(self, async_test_db): """Test restore with invalid UUID returns None.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.restore(session, id="invalid-uuid") @@ -716,7 +775,7 @@ class TestCRUDBaseRestore: @pytest.mark.asyncio async def test_restore_nonexistent(self, async_test_db): """Test restore of nonexistent record returns None.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: result = await user_crud.restore(session, id=str(uuid4())) @@ -725,7 +784,7 @@ class TestCRUDBaseRestore: @pytest.mark.asyncio async def test_restore_not_deleted(self, async_test_db, async_test_user): """Test restore of non-deleted record returns None.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: # Try to restore a user that's not deleted @@ -735,7 +794,7 @@ class TestCRUDBaseRestore: @pytest.mark.asyncio async def test_restore_with_uuid_object(self, async_test_db): """Test restore with UUID object.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create and soft delete a user async with SessionLocal() as session: @@ -743,7 +802,7 @@ class TestCRUDBaseRestore: email="restore2@example.com", password="TestPassword123!", first_name="Restore", - last_name="Test2" + last_name="Test2", ) user = await user_crud.create(session, obj_in=user_data) user_id = user.id @@ -765,7 +824,7 @@ class TestCRUDBasePaginationValidation: @pytest.mark.asyncio async def test_get_multi_with_total_negative_skip(self, async_test_db): """Test that negative skip raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: with pytest.raises(ValueError, match="skip must be non-negative"): @@ -774,7 +833,7 @@ class TestCRUDBasePaginationValidation: @pytest.mark.asyncio async def test_get_multi_with_total_negative_limit(self, async_test_db): """Test that negative limit raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: with pytest.raises(ValueError, match="limit must be non-negative"): @@ -783,23 +842,22 @@ class TestCRUDBasePaginationValidation: @pytest.mark.asyncio async def test_get_multi_with_total_limit_too_large(self, async_test_db): """Test that limit > 1000 raises ValueError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: with pytest.raises(ValueError, match="Maximum limit is 1000"): await user_crud.get_multi_with_total(session, skip=0, limit=1001) @pytest.mark.asyncio - async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user): + async def test_get_multi_with_total_with_filters( + self, async_test_db, async_test_user + ): """Test pagination with filters (covers lines 270-273).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: users, total = await user_crud.get_multi_with_total( - session, - skip=0, - limit=10, - filters={"is_active": True} + session, skip=0, limit=10, filters={"is_active": True} ) assert isinstance(users, list) assert total >= 0 @@ -807,30 +865,22 @@ class TestCRUDBasePaginationValidation: @pytest.mark.asyncio async def test_get_multi_with_total_with_sorting_desc(self, async_test_db): """Test pagination with descending sort (covers lines 283-284).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - users, total = await user_crud.get_multi_with_total( - session, - skip=0, - limit=10, - sort_by="created_at", - sort_order="desc" + users, _total = await user_crud.get_multi_with_total( + session, skip=0, limit=10, sort_by="created_at", sort_order="desc" ) assert isinstance(users, list) @pytest.mark.asyncio async def test_get_multi_with_total_with_sorting_asc(self, async_test_db): """Test pagination with ascending sort (covers lines 285-286).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - users, total = await user_crud.get_multi_with_total( - session, - skip=0, - limit=10, - sort_by="created_at", - sort_order="asc" + users, _total = await user_crud.get_multi_with_total( + session, skip=0, limit=10, sort_by="created_at", sort_order="asc" ) assert isinstance(users, list) @@ -842,13 +892,15 @@ class TestCRUDBaseModelsWithoutSoftDelete: """ @pytest.mark.asyncio - async def test_soft_delete_model_without_deleted_at(self, async_test_db, async_test_user): + async def test_soft_delete_model_without_deleted_at( + self, async_test_db, async_test_user + ): """Test soft_delete on Organization model (no deleted_at) raises ValueError (covers lines 342-343).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create an organization (which doesn't have deleted_at) - from app.models.organization import Organization from app.crud.organization import organization as org_crud + from app.models.organization import Organization async with SessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -864,11 +916,11 @@ class TestCRUDBaseModelsWithoutSoftDelete: @pytest.mark.asyncio async def test_restore_model_without_deleted_at(self, async_test_db): """Test restore on Organization model (no deleted_at) raises ValueError (covers lines 383-384).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create an organization (which doesn't have deleted_at) - from app.models.organization import Organization from app.crud.organization import organization as org_crud + from app.models.organization import Organization async with SessionLocal() as session: org = Organization(name="Restore Test", slug="restore-test") @@ -889,14 +941,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions: """ @pytest.mark.asyncio - async def test_get_with_real_eager_loading_options(self, async_test_db, async_test_user): + async def test_get_with_real_eager_loading_options( + self, async_test_db, async_test_user + ): """Test get() with actual eager loading options (covers lines 77-78).""" - from datetime import datetime, timedelta, timezone - test_engine, SessionLocal = async_test_db + from datetime import datetime, timedelta + + _test_engine, SessionLocal = async_test_db # Create a session for the user - from app.models.user_session import UserSession from app.crud.session import session as session_crud + from app.models.user_session import UserSession async with SessionLocal() as session: user_session = UserSession( @@ -905,8 +960,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions: device_id="test-device", ip_address="192.168.1.1", user_agent="Test Agent", - last_used_at=datetime.now(timezone.utc), - expires_at=datetime.now(timezone.utc) + timedelta(days=60) + last_used_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(days=60), ) session.add(user_session) await session.commit() @@ -917,7 +972,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions: result = await session_crud.get( session, id=str(session_id), - options=[joinedload(UserSession.user)] # Real option, not empty list + options=[joinedload(UserSession.user)], # Real option, not empty list ) assert result is not None assert result.id == session_id @@ -925,14 +980,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions: assert result.user.email == async_test_user.email @pytest.mark.asyncio - async def test_get_multi_with_real_eager_loading_options(self, async_test_db, async_test_user): + async def test_get_multi_with_real_eager_loading_options( + self, async_test_db, async_test_user + ): """Test get_multi() with actual eager loading options (covers lines 119-120).""" - from datetime import datetime, timedelta, timezone - test_engine, SessionLocal = async_test_db + from datetime import datetime, timedelta + + _test_engine, SessionLocal = async_test_db # Create multiple sessions for the user - from app.models.user_session import UserSession from app.crud.session import session as session_crud + from app.models.user_session import UserSession async with SessionLocal() as session: for i in range(3): @@ -942,8 +1000,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions: device_id=f"device-{i}", ip_address=f"192.168.1.{i}", user_agent=f"Agent {i}", - last_used_at=datetime.now(timezone.utc), - expires_at=datetime.now(timezone.utc) + timedelta(days=60) + last_used_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(days=60), ) session.add(user_session) await session.commit() @@ -954,7 +1012,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions: session, skip=0, limit=10, - options=[joinedload(UserSession.user)] # Real option, not empty list + options=[joinedload(UserSession.user)], # Real option, not empty list ) assert len(results) >= 3 # Verify we can access user without additional queries diff --git a/backend/tests/crud/test_base_db_failures.py b/backend/tests/crud/test_base_db_failures.py index e93bc41..36e0991 100644 --- a/backend/tests/crud/test_base_db_failures.py +++ b/backend/tests/crud/test_base_db_failures.py @@ -3,13 +3,15 @@ Comprehensive tests for base CRUD database failure scenarios. Tests exception handling, rollbacks, and error messages. """ -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from sqlalchemy.exc import IntegrityError, OperationalError, DataError + +from unittest.mock import AsyncMock, patch from uuid import uuid4 +import pytest +from sqlalchemy.exc import DataError, OperationalError + from app.crud.user import user as user_crud -from app.schemas.users import UserCreate, UserUpdate +from app.schemas.users import UserCreate class TestBaseCRUDCreateFailures: @@ -18,19 +20,24 @@ class TestBaseCRUDCreateFailures: @pytest.mark.asyncio async def test_create_operational_error_triggers_rollback(self, async_test_db): """Test that OperationalError triggers rollback (User CRUD catches as Exception).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: - async def mock_commit(): - raise OperationalError("Connection lost", {}, Exception("DB connection failed")) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + async def mock_commit(): + raise OperationalError( + "Connection lost", {}, Exception("DB connection failed") + ) + + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: user_data = UserCreate( email="operror@example.com", password="TestPassword123!", first_name="Test", - last_name="User" + last_name="User", ) # User CRUD catches this as generic Exception and re-raises @@ -43,19 +50,22 @@ class TestBaseCRUDCreateFailures: @pytest.mark.asyncio async def test_create_data_error_triggers_rollback(self, async_test_db): """Test that DataError triggers rollback (User CRUD catches as Exception).""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_commit(): raise DataError("Invalid data type", {}, Exception("Data overflow")) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: user_data = UserCreate( email="dataerror@example.com", password="TestPassword123!", first_name="Test", - last_name="User" + last_name="User", ) # User CRUD catches this as generic Exception and re-raises @@ -67,19 +77,22 @@ class TestBaseCRUDCreateFailures: @pytest.mark.asyncio async def test_create_unexpected_exception_triggers_rollback(self, async_test_db): """Test that unexpected exceptions trigger rollback and re-raise.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_commit(): raise RuntimeError("Unexpected database error") - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: user_data = UserCreate( email="unexpected@example.com", password="TestPassword123!", first_name="Test", - last_name="User" + last_name="User", ) with pytest.raises(RuntimeError, match="Unexpected database error"): @@ -94,7 +107,7 @@ class TestBaseCRUDUpdateFailures: @pytest.mark.asyncio async def test_update_operational_error(self, async_test_db, async_test_user): """Test update with OperationalError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) @@ -102,17 +115,21 @@ class TestBaseCRUDUpdateFailures: async def mock_commit(): raise OperationalError("Connection timeout", {}, Exception("Timeout")) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(ValueError, match="Database operation failed"): - await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"}) + await user_crud.update( + session, db_obj=user, obj_in={"first_name": "Updated"} + ) mock_rollback.assert_called_once() @pytest.mark.asyncio async def test_update_data_error(self, async_test_db, async_test_user): """Test update with DataError.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) @@ -120,17 +137,21 @@ class TestBaseCRUDUpdateFailures: async def mock_commit(): raise DataError("Invalid data", {}, Exception("Data type mismatch")) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(ValueError, match="Database operation failed"): - await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"}) + await user_crud.update( + session, db_obj=user, obj_in={"first_name": "Updated"} + ) mock_rollback.assert_called_once() @pytest.mark.asyncio async def test_update_unexpected_error(self, async_test_db, async_test_user): """Test update with unexpected error.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) @@ -138,10 +159,14 @@ class TestBaseCRUDUpdateFailures: async def mock_commit(): raise KeyError("Unexpected error") - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(KeyError): - await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"}) + await user_crud.update( + session, db_obj=user, obj_in={"first_name": "Updated"} + ) mock_rollback.assert_called_once() @@ -150,16 +175,21 @@ class TestBaseCRUDRemoveFailures: """Test base CRUD remove method exception handling.""" @pytest.mark.asyncio - async def test_remove_unexpected_error_triggers_rollback(self, async_test_db, async_test_user): + async def test_remove_unexpected_error_triggers_rollback( + self, async_test_db, async_test_user + ): """Test that unexpected errors in remove trigger rollback.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_commit(): raise RuntimeError("Database write failed") - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(RuntimeError, match="Database write failed"): await user_crud.remove(session, id=str(async_test_user.id)) @@ -172,16 +202,15 @@ class TestBaseCRUDGetMultiWithTotalFailures: @pytest.mark.asyncio async def test_get_multi_with_total_database_error(self, async_test_db): """Test get_multi_with_total handles database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: # Mock execute to raise an error - original_execute = session.execute async def mock_execute(*args, **kwargs): raise OperationalError("Query failed", {}, Exception("Database error")) - with patch.object(session, 'execute', side_effect=mock_execute): + with patch.object(session, "execute", side_effect=mock_execute): with pytest.raises(OperationalError): await user_crud.get_multi_with_total(session, skip=0, limit=10) @@ -192,13 +221,14 @@ class TestBaseCRUDCountFailures: @pytest.mark.asyncio async def test_count_database_error_propagates(self, async_test_db): """Test count propagates database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_execute(*args, **kwargs): raise OperationalError("Count failed", {}, Exception("DB error")) - with patch.object(session, 'execute', side_effect=mock_execute): + with patch.object(session, "execute", side_effect=mock_execute): with pytest.raises(OperationalError): await user_crud.count(session) @@ -207,16 +237,21 @@ class TestBaseCRUDSoftDeleteFailures: """Test soft_delete method exception handling.""" @pytest.mark.asyncio - async def test_soft_delete_unexpected_error_triggers_rollback(self, async_test_db, async_test_user): + async def test_soft_delete_unexpected_error_triggers_rollback( + self, async_test_db, async_test_user + ): """Test soft_delete handles unexpected errors with rollback.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_commit(): raise RuntimeError("Soft delete failed") - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(RuntimeError, match="Soft delete failed"): await user_crud.soft_delete(session, id=str(async_test_user.id)) @@ -229,7 +264,7 @@ class TestBaseCRUDRestoreFailures: @pytest.mark.asyncio async def test_restore_unexpected_error_triggers_rollback(self, async_test_db): """Test restore handles unexpected errors with rollback.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # First create and soft delete a user async with SessionLocal() as session: @@ -237,7 +272,7 @@ class TestBaseCRUDRestoreFailures: email="restore_test@example.com", password="TestPassword123!", first_name="Restore", - last_name="Test" + last_name="Test", ) user = await user_crud.create(session, obj_in=user_data) user_id = user.id @@ -248,11 +283,14 @@ class TestBaseCRUDRestoreFailures: # Now test restore failure async with SessionLocal() as session: + async def mock_commit(): raise RuntimeError("Restore failed") - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(RuntimeError, match="Restore failed"): await user_crud.restore(session, id=str(user_id)) @@ -265,13 +303,14 @@ class TestBaseCRUDGetFailures: @pytest.mark.asyncio async def test_get_database_error_propagates(self, async_test_db): """Test get propagates database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_execute(*args, **kwargs): raise OperationalError("Get failed", {}, Exception("DB error")) - with patch.object(session, 'execute', side_effect=mock_execute): + with patch.object(session, "execute", side_effect=mock_execute): with pytest.raises(OperationalError): await user_crud.get(session, id=str(uuid4())) @@ -282,12 +321,13 @@ class TestBaseCRUDGetMultiFailures: @pytest.mark.asyncio async def test_get_multi_database_error_propagates(self, async_test_db): """Test get_multi propagates database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_execute(*args, **kwargs): raise OperationalError("Query failed", {}, Exception("DB error")) - with patch.object(session, 'execute', side_effect=mock_execute): + with patch.object(session, "execute", side_effect=mock_execute): with pytest.raises(OperationalError): await user_crud.get_multi(session, skip=0, limit=10) diff --git a/backend/tests/crud/test_organization.py b/backend/tests/crud/test_organization.py index 935f4dd..1a7b2e9 100644 --- a/backend/tests/crud/test_organization.py +++ b/backend/tests/crud/test_organization.py @@ -2,17 +2,17 @@ """ Comprehensive tests for async organization CRUD operations. """ -import pytest + +from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 + +import pytest from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from unittest.mock import patch, AsyncMock, MagicMock from app.crud.organization import organization as organization_crud from app.models.organization import Organization -from app.models.user_organization import UserOrganization, OrganizationRole -from app.models.user import User -from app.schemas.organizations import OrganizationCreate, OrganizationUpdate +from app.models.user_organization import OrganizationRole, UserOrganization +from app.schemas.organizations import OrganizationCreate class TestGetBySlug: @@ -21,14 +21,12 @@ class TestGetBySlug: @pytest.mark.asyncio async def test_get_by_slug_success(self, async_test_db): """Test successfully getting an organization by slug.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organization async with AsyncTestingSessionLocal() as session: org = Organization( - name="Test Org", - slug="test-org", - description="Test description" + name="Test Org", slug="test-org", description="Test description" ) session.add(org) await session.commit() @@ -44,7 +42,7 @@ class TestGetBySlug: @pytest.mark.asyncio async def test_get_by_slug_not_found(self, async_test_db): """Test getting non-existent organization by slug.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: result = await organization_crud.get_by_slug(session, slug="nonexistent") @@ -57,7 +55,7 @@ class TestCreate: @pytest.mark.asyncio async def test_create_success(self, async_test_db): """Test successfully creating an organization_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org_in = OrganizationCreate( @@ -65,7 +63,7 @@ class TestCreate: slug="new-org", description="New organization", is_active=True, - settings={"key": "value"} + settings={"key": "value"}, ) result = await organization_crud.create(session, obj_in=org_in) @@ -78,7 +76,7 @@ class TestCreate: @pytest.mark.asyncio async def test_create_duplicate_slug(self, async_test_db): """Test creating organization with duplicate slug raises error.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create first org async with AsyncTestingSessionLocal() as session: @@ -88,23 +86,17 @@ class TestCreate: # Try to create second with same slug async with AsyncTestingSessionLocal() as session: - org_in = OrganizationCreate( - name="Org 2", - slug="duplicate-slug" - ) + org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug") with pytest.raises(ValueError, match="already exists"): await organization_crud.create(session, obj_in=org_in) @pytest.mark.asyncio async def test_create_without_settings(self, async_test_db): """Test creating organization without settings (defaults to empty dict).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - org_in = OrganizationCreate( - name="No Settings Org", - slug="no-settings" - ) + org_in = OrganizationCreate(name="No Settings Org", slug="no-settings") result = await organization_crud.create(session, obj_in=org_in) assert result.settings == {} @@ -116,7 +108,7 @@ class TestGetMultiWithFilters: @pytest.mark.asyncio async def test_get_multi_with_filters_no_filters(self, async_test_db): """Test getting organizations without any filters.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create test organizations async with AsyncTestingSessionLocal() as session: @@ -133,7 +125,7 @@ class TestGetMultiWithFilters: @pytest.mark.asyncio async def test_get_multi_with_filters_is_active(self, async_test_db): """Test filtering by is_active.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org1 = Organization(name="Active", slug="active", is_active=True) @@ -143,8 +135,7 @@ class TestGetMultiWithFilters: async with AsyncTestingSessionLocal() as session: orgs, total = await organization_crud.get_multi_with_filters( - session, - is_active=True + session, is_active=True ) assert total == 1 assert orgs[0].name == "Active" @@ -152,18 +143,21 @@ class TestGetMultiWithFilters: @pytest.mark.asyncio async def test_get_multi_with_filters_search(self, async_test_db): """Test searching organizations.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - org1 = Organization(name="Tech Corp", slug="tech-corp", description="Technology") - org2 = Organization(name="Food Inc", slug="food-inc", description="Restaurant") + org1 = Organization( + name="Tech Corp", slug="tech-corp", description="Technology" + ) + org2 = Organization( + name="Food Inc", slug="food-inc", description="Restaurant" + ) session.add_all([org1, org2]) await session.commit() async with AsyncTestingSessionLocal() as session: orgs, total = await organization_crud.get_multi_with_filters( - session, - search="tech" + session, search="tech" ) assert total == 1 assert orgs[0].name == "Tech Corp" @@ -171,7 +165,7 @@ class TestGetMultiWithFilters: @pytest.mark.asyncio async def test_get_multi_with_filters_pagination(self, async_test_db): """Test pagination.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: for i in range(10): @@ -181,9 +175,7 @@ class TestGetMultiWithFilters: async with AsyncTestingSessionLocal() as session: orgs, total = await organization_crud.get_multi_with_filters( - session, - skip=2, - limit=3 + session, skip=2, limit=3 ) assert total == 10 assert len(orgs) == 3 @@ -191,7 +183,7 @@ class TestGetMultiWithFilters: @pytest.mark.asyncio async def test_get_multi_with_filters_sorting(self, async_test_db): """Test sorting.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org1 = Organization(name="B Org", slug="b-org") @@ -200,10 +192,8 @@ class TestGetMultiWithFilters: await session.commit() async with AsyncTestingSessionLocal() as session: - orgs, total = await organization_crud.get_multi_with_filters( - session, - sort_by="name", - sort_order="asc" + orgs, _total = await organization_crud.get_multi_with_filters( + session, sort_by="name", sort_order="asc" ) assert orgs[0].name == "A Org" assert orgs[1].name == "B Org" @@ -215,7 +205,7 @@ class TestGetMemberCount: @pytest.mark.asyncio async def test_get_member_count_success(self, async_test_db, async_test_user): """Test getting member count for organization_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -227,20 +217,22 @@ class TestGetMemberCount: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() org_id = org.id async with AsyncTestingSessionLocal() as session: - count = await organization_crud.get_member_count(session, organization_id=org_id) + count = await organization_crud.get_member_count( + session, organization_id=org_id + ) assert count == 1 @pytest.mark.asyncio async def test_get_member_count_no_members(self, async_test_db): """Test getting member count for organization with no members.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Empty Org", slug="empty-org") @@ -249,7 +241,9 @@ class TestGetMemberCount: org_id = org.id async with AsyncTestingSessionLocal() as session: - count = await organization_crud.get_member_count(session, organization_id=org_id) + count = await organization_crud.get_member_count( + session, organization_id=org_id + ) assert count == 0 @@ -259,7 +253,7 @@ class TestAddUser: @pytest.mark.asyncio async def test_add_user_success(self, async_test_db, async_test_user): """Test successfully adding a user to organization_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -272,7 +266,7 @@ class TestAddUser: session, organization_id=org_id, user_id=async_test_user.id, - role=OrganizationRole.ADMIN + role=OrganizationRole.ADMIN, ) assert result.user_id == async_test_user.id @@ -283,7 +277,7 @@ class TestAddUser: @pytest.mark.asyncio async def test_add_user_already_active_member(self, async_test_db, async_test_user): """Test adding user who is already an active member raises error.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -294,7 +288,7 @@ class TestAddUser: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -303,15 +297,13 @@ class TestAddUser: async with AsyncTestingSessionLocal() as session: with pytest.raises(ValueError, match="already a member"): await organization_crud.add_user( - session, - organization_id=org_id, - user_id=async_test_user.id + session, organization_id=org_id, user_id=async_test_user.id ) @pytest.mark.asyncio async def test_add_user_reactivate_inactive(self, async_test_db, async_test_user): """Test adding user who was previously inactive reactivates them.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -322,7 +314,7 @@ class TestAddUser: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=False + is_active=False, ) session.add(user_org) await session.commit() @@ -333,7 +325,7 @@ class TestAddUser: session, organization_id=org_id, user_id=async_test_user.id, - role=OrganizationRole.ADMIN + role=OrganizationRole.ADMIN, ) assert result.is_active is True @@ -346,7 +338,7 @@ class TestRemoveUser: @pytest.mark.asyncio async def test_remove_user_success(self, async_test_db, async_test_user): """Test successfully removing a user from organization_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -357,7 +349,7 @@ class TestRemoveUser: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -365,9 +357,7 @@ class TestRemoveUser: async with AsyncTestingSessionLocal() as session: result = await organization_crud.remove_user( - session, - organization_id=org_id, - user_id=async_test_user.id + session, organization_id=org_id, user_id=async_test_user.id ) assert result is True @@ -376,7 +366,7 @@ class TestRemoveUser: async with AsyncTestingSessionLocal() as session: stmt = select(UserOrganization).where( UserOrganization.user_id == async_test_user.id, - UserOrganization.organization_id == org_id + UserOrganization.organization_id == org_id, ) result = await session.execute(stmt) user_org = result.scalar_one_or_none() @@ -385,7 +375,7 @@ class TestRemoveUser: @pytest.mark.asyncio async def test_remove_user_not_found(self, async_test_db): """Test removing non-existent user returns False.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -395,9 +385,7 @@ class TestRemoveUser: async with AsyncTestingSessionLocal() as session: result = await organization_crud.remove_user( - session, - organization_id=org_id, - user_id=uuid4() + session, organization_id=org_id, user_id=uuid4() ) assert result is False @@ -409,7 +397,7 @@ class TestUpdateUserRole: @pytest.mark.asyncio async def test_update_user_role_success(self, async_test_db, async_test_user): """Test successfully updating user role.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -420,7 +408,7 @@ class TestUpdateUserRole: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -432,7 +420,7 @@ class TestUpdateUserRole: organization_id=org_id, user_id=async_test_user.id, role=OrganizationRole.ADMIN, - custom_permissions="custom" + custom_permissions="custom", ) assert result.role == OrganizationRole.ADMIN @@ -441,7 +429,7 @@ class TestUpdateUserRole: @pytest.mark.asyncio async def test_update_user_role_not_found(self, async_test_db): """Test updating role for non-existent user returns None.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -454,7 +442,7 @@ class TestUpdateUserRole: session, organization_id=org_id, user_id=uuid4(), - role=OrganizationRole.ADMIN + role=OrganizationRole.ADMIN, ) assert result is None @@ -464,9 +452,11 @@ class TestGetOrganizationMembers: """Tests for get_organization_members method.""" @pytest.mark.asyncio - async def test_get_organization_members_success(self, async_test_db, async_test_user): + async def test_get_organization_members_success( + self, async_test_db, async_test_user + ): """Test getting organization members.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -477,7 +467,7 @@ class TestGetOrganizationMembers: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.ADMIN, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -485,8 +475,7 @@ class TestGetOrganizationMembers: async with AsyncTestingSessionLocal() as session: members, total = await organization_crud.get_organization_members( - session, - organization_id=org_id + session, organization_id=org_id ) assert total == 1 @@ -496,9 +485,11 @@ class TestGetOrganizationMembers: assert members[0]["role"] == OrganizationRole.ADMIN @pytest.mark.asyncio - async def test_get_organization_members_with_pagination(self, async_test_db, async_test_user): + async def test_get_organization_members_with_pagination( + self, async_test_db, async_test_user + ): """Test getting organization members with pagination.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -509,7 +500,7 @@ class TestGetOrganizationMembers: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -517,10 +508,7 @@ class TestGetOrganizationMembers: async with AsyncTestingSessionLocal() as session: members, total = await organization_crud.get_organization_members( - session, - organization_id=org_id, - skip=0, - limit=10 + session, organization_id=org_id, skip=0, limit=10 ) assert total == 1 @@ -533,7 +521,7 @@ class TestGetUserOrganizations: @pytest.mark.asyncio async def test_get_user_organizations_success(self, async_test_db, async_test_user): """Test getting user's organizations.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -544,24 +532,25 @@ class TestGetUserOrganizations: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() async with AsyncTestingSessionLocal() as session: orgs = await organization_crud.get_user_organizations( - session, - user_id=async_test_user.id + session, user_id=async_test_user.id ) assert len(orgs) == 1 assert orgs[0].name == "Test Org" @pytest.mark.asyncio - async def test_get_user_organizations_filter_inactive(self, async_test_db, async_test_user): + async def test_get_user_organizations_filter_inactive( + self, async_test_db, async_test_user + ): """Test filtering inactive organizations.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org1 = Organization(name="Active Org", slug="active-org") @@ -573,22 +562,20 @@ class TestGetUserOrganizations: user_id=async_test_user.id, organization_id=org1.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) user_org2 = UserOrganization( user_id=async_test_user.id, organization_id=org2.id, role=OrganizationRole.MEMBER, - is_active=False + is_active=False, ) session.add_all([user_org1, user_org2]) await session.commit() async with AsyncTestingSessionLocal() as session: orgs = await organization_crud.get_user_organizations( - session, - user_id=async_test_user.id, - is_active=True + session, user_id=async_test_user.id, is_active=True ) assert len(orgs) == 1 @@ -601,7 +588,7 @@ class TestGetUserRole: @pytest.mark.asyncio async def test_get_user_role_in_org_success(self, async_test_db, async_test_user): """Test getting user role in organization_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -612,7 +599,7 @@ class TestGetUserRole: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.ADMIN, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -620,9 +607,7 @@ class TestGetUserRole: async with AsyncTestingSessionLocal() as session: role = await organization_crud.get_user_role_in_org( - session, - user_id=async_test_user.id, - organization_id=org_id + session, user_id=async_test_user.id, organization_id=org_id ) assert role == OrganizationRole.ADMIN @@ -630,7 +615,7 @@ class TestGetUserRole: @pytest.mark.asyncio async def test_get_user_role_in_org_not_found(self, async_test_db): """Test getting role for non-member returns None.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -640,9 +625,7 @@ class TestGetUserRole: async with AsyncTestingSessionLocal() as session: role = await organization_crud.get_user_role_in_org( - session, - user_id=uuid4(), - organization_id=org_id + session, user_id=uuid4(), organization_id=org_id ) assert role is None @@ -654,7 +637,7 @@ class TestIsUserOrgOwner: @pytest.mark.asyncio async def test_is_user_org_owner_true(self, async_test_db, async_test_user): """Test checking if user is owner.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -665,7 +648,7 @@ class TestIsUserOrgOwner: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.OWNER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -673,9 +656,7 @@ class TestIsUserOrgOwner: async with AsyncTestingSessionLocal() as session: is_owner = await organization_crud.is_user_org_owner( - session, - user_id=async_test_user.id, - organization_id=org_id + session, user_id=async_test_user.id, organization_id=org_id ) assert is_owner is True @@ -683,7 +664,7 @@ class TestIsUserOrgOwner: @pytest.mark.asyncio async def test_is_user_org_owner_false(self, async_test_db, async_test_user): """Test checking if non-owner user is owner.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -694,7 +675,7 @@ class TestIsUserOrgOwner: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -702,9 +683,7 @@ class TestIsUserOrgOwner: async with AsyncTestingSessionLocal() as session: is_owner = await organization_crud.is_user_org_owner( - session, - user_id=async_test_user.id, - organization_id=org_id + session, user_id=async_test_user.id, organization_id=org_id ) assert is_owner is False @@ -714,9 +693,11 @@ class TestGetMultiWithMemberCounts: """Tests for get_multi_with_member_counts method.""" @pytest.mark.asyncio - async def test_get_multi_with_member_counts_success(self, async_test_db, async_test_user): + async def test_get_multi_with_member_counts_success( + self, async_test_db, async_test_user + ): """Test getting organizations with member counts.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org1 = Organization(name="Org 1", slug="org-1") @@ -729,44 +710,51 @@ class TestGetMultiWithMemberCounts: user_id=async_test_user.id, organization_id=org1.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org1) await session.commit() async with AsyncTestingSessionLocal() as session: - orgs_with_counts, total = await organization_crud.get_multi_with_member_counts(session) + ( + orgs_with_counts, + total, + ) = await organization_crud.get_multi_with_member_counts(session) assert total == 2 assert len(orgs_with_counts) == 2 # Verify structure - assert 'organization' in orgs_with_counts[0] - assert 'member_count' in orgs_with_counts[0] + assert "organization" in orgs_with_counts[0] + assert "member_count" in orgs_with_counts[0] @pytest.mark.asyncio async def test_get_multi_with_member_counts_with_filters(self, async_test_db): """Test getting organizations with member counts and filters.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org1 = Organization(name="Active Org", slug="active-org", is_active=True) - org2 = Organization(name="Inactive Org", slug="inactive-org", is_active=False) + org2 = Organization( + name="Inactive Org", slug="inactive-org", is_active=False + ) session.add_all([org1, org2]) await session.commit() async with AsyncTestingSessionLocal() as session: - orgs_with_counts, total = await organization_crud.get_multi_with_member_counts( - session, - is_active=True + ( + orgs_with_counts, + total, + ) = await organization_crud.get_multi_with_member_counts( + session, is_active=True ) assert total == 1 - assert orgs_with_counts[0]['organization'].name == "Active Org" + assert orgs_with_counts[0]["organization"].name == "Active Org" @pytest.mark.asyncio async def test_get_multi_with_member_counts_with_search(self, async_test_db): """Test searching organizations with member counts.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org1 = Organization(name="Tech Corp", slug="tech-corp") @@ -775,22 +763,26 @@ class TestGetMultiWithMemberCounts: await session.commit() async with AsyncTestingSessionLocal() as session: - orgs_with_counts, total = await organization_crud.get_multi_with_member_counts( - session, - search="tech" + ( + orgs_with_counts, + total, + ) = await organization_crud.get_multi_with_member_counts( + session, search="tech" ) assert total == 1 - assert orgs_with_counts[0]['organization'].name == "Tech Corp" + assert orgs_with_counts[0]["organization"].name == "Tech Corp" class TestGetUserOrganizationsWithDetails: """Tests for get_user_organizations_with_details method.""" @pytest.mark.asyncio - async def test_get_user_organizations_with_details_success(self, async_test_db, async_test_user): + async def test_get_user_organizations_with_details_success( + self, async_test_db, async_test_user + ): """Test getting user organizations with role and member count.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -801,26 +793,29 @@ class TestGetUserOrganizationsWithDetails: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.ADMIN, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() async with AsyncTestingSessionLocal() as session: - orgs_with_details = await organization_crud.get_user_organizations_with_details( - session, - user_id=async_test_user.id + orgs_with_details = ( + await organization_crud.get_user_organizations_with_details( + session, user_id=async_test_user.id + ) ) assert len(orgs_with_details) == 1 - assert orgs_with_details[0]['organization'].name == "Test Org" - assert orgs_with_details[0]['role'] == OrganizationRole.ADMIN - assert 'member_count' in orgs_with_details[0] + assert orgs_with_details[0]["organization"].name == "Test Org" + assert orgs_with_details[0]["role"] == OrganizationRole.ADMIN + assert "member_count" in orgs_with_details[0] @pytest.mark.asyncio - async def test_get_user_organizations_with_details_filter_inactive(self, async_test_db, async_test_user): + async def test_get_user_organizations_with_details_filter_inactive( + self, async_test_db, async_test_user + ): """Test filtering inactive organizations in user details.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org1 = Organization(name="Active Org", slug="active-org") @@ -832,26 +827,26 @@ class TestGetUserOrganizationsWithDetails: user_id=async_test_user.id, organization_id=org1.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) user_org2 = UserOrganization( user_id=async_test_user.id, organization_id=org2.id, role=OrganizationRole.MEMBER, - is_active=False + is_active=False, ) session.add_all([user_org1, user_org2]) await session.commit() async with AsyncTestingSessionLocal() as session: - orgs_with_details = await organization_crud.get_user_organizations_with_details( - session, - user_id=async_test_user.id, - is_active=True + orgs_with_details = ( + await organization_crud.get_user_organizations_with_details( + session, user_id=async_test_user.id, is_active=True + ) ) assert len(orgs_with_details) == 1 - assert orgs_with_details[0]['organization'].name == "Active Org" + assert orgs_with_details[0]["organization"].name == "Active Org" class TestIsUserOrgAdmin: @@ -860,7 +855,7 @@ class TestIsUserOrgAdmin: @pytest.mark.asyncio async def test_is_user_org_admin_owner(self, async_test_db, async_test_user): """Test checking if owner is admin (should be True).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -871,7 +866,7 @@ class TestIsUserOrgAdmin: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.OWNER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -879,9 +874,7 @@ class TestIsUserOrgAdmin: async with AsyncTestingSessionLocal() as session: is_admin = await organization_crud.is_user_org_admin( - session, - user_id=async_test_user.id, - organization_id=org_id + session, user_id=async_test_user.id, organization_id=org_id ) assert is_admin is True @@ -889,7 +882,7 @@ class TestIsUserOrgAdmin: @pytest.mark.asyncio async def test_is_user_org_admin_admin_role(self, async_test_db, async_test_user): """Test checking if admin role is admin.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -900,7 +893,7 @@ class TestIsUserOrgAdmin: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.ADMIN, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -908,9 +901,7 @@ class TestIsUserOrgAdmin: async with AsyncTestingSessionLocal() as session: is_admin = await organization_crud.is_user_org_admin( - session, - user_id=async_test_user.id, - organization_id=org_id + session, user_id=async_test_user.id, organization_id=org_id ) assert is_admin is True @@ -918,7 +909,7 @@ class TestIsUserOrgAdmin: @pytest.mark.asyncio async def test_is_user_org_admin_member_false(self, async_test_db, async_test_user): """Test checking if regular member is admin.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: org = Organization(name="Test Org", slug="test-org") @@ -929,7 +920,7 @@ class TestIsUserOrgAdmin: user_id=async_test_user.id, organization_id=org.id, role=OrganizationRole.MEMBER, - is_active=True + is_active=True, ) session.add(user_org) await session.commit() @@ -937,9 +928,7 @@ class TestIsUserOrgAdmin: async with AsyncTestingSessionLocal() as session: is_admin = await organization_crud.is_user_org_admin( - session, - user_id=async_test_user.id, - organization_id=org_id + session, user_id=async_test_user.id, organization_id=org_id ) assert is_admin is False @@ -955,10 +944,12 @@ class TestOrganizationExceptionHandlers: @pytest.mark.asyncio async def test_get_by_slug_database_error(self, async_test_db): """Test get_by_slug handles database errors (covers lines 33-35).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Database connection lost")): + with patch.object( + session, "execute", side_effect=Exception("Database connection lost") + ): with pytest.raises(Exception, match="Database connection lost"): await organization_crud.get_by_slug(session, slug="test-slug") @@ -966,16 +957,20 @@ class TestOrganizationExceptionHandlers: async def test_create_integrity_error_non_slug(self, async_test_db): """Test create with non-slug IntegrityError (covers lines 56-57).""" from sqlalchemy.exc import IntegrityError - test_engine, AsyncTestingSessionLocal = async_test_db + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: + async def mock_commit(): - error = IntegrityError("statement", {}, Exception("foreign key constraint failed")) + error = IntegrityError( + "statement", {}, Exception("foreign key constraint failed") + ) error.orig = Exception("foreign key constraint failed") raise error - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock): + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object(session, "rollback", new_callable=AsyncMock): org_in = OrganizationCreate(name="Test", slug="test") with pytest.raises(ValueError, match="Database integrity error"): await organization_crud.create(session, obj_in=org_in) @@ -983,11 +978,13 @@ class TestOrganizationExceptionHandlers: @pytest.mark.asyncio async def test_create_unexpected_error(self, async_test_db): """Test create with unexpected exception (covers lines 58-62).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected error")): - with patch.object(session, 'rollback', new_callable=AsyncMock): + with patch.object( + session, "commit", side_effect=RuntimeError("Unexpected error") + ): + with patch.object(session, "rollback", new_callable=AsyncMock): org_in = OrganizationCreate(name="Test", slug="test") with pytest.raises(RuntimeError, match="Unexpected error"): await organization_crud.create(session, obj_in=org_in) @@ -995,10 +992,12 @@ class TestOrganizationExceptionHandlers: @pytest.mark.asyncio async def test_get_multi_with_filters_database_error(self, async_test_db): """Test get_multi_with_filters handles database errors (covers lines 114-116).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Query timeout")): + with patch.object( + session, "execute", side_effect=Exception("Query timeout") + ): with pytest.raises(Exception, match="Query timeout"): await organization_crud.get_multi_with_filters(session) @@ -1006,20 +1005,27 @@ class TestOrganizationExceptionHandlers: async def test_get_member_count_database_error(self, async_test_db): """Test get_member_count handles database errors (covers lines 130-132).""" from uuid import uuid4 - test_engine, AsyncTestingSessionLocal = async_test_db + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Count query failed")): + with patch.object( + session, "execute", side_effect=Exception("Count query failed") + ): with pytest.raises(Exception, match="Count query failed"): - await organization_crud.get_member_count(session, organization_id=uuid4()) + await organization_crud.get_member_count( + session, organization_id=uuid4() + ) @pytest.mark.asyncio async def test_get_multi_with_member_counts_database_error(self, async_test_db): """Test get_multi_with_member_counts handles database errors (covers lines 207-209).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Complex query failed")): + with patch.object( + session, "execute", side_effect=Exception("Complex query failed") + ): with pytest.raises(Exception, match="Complex query failed"): await organization_crud.get_multi_with_member_counts(session) @@ -1027,8 +1033,8 @@ class TestOrganizationExceptionHandlers: async def test_add_user_integrity_error(self, async_test_db, async_test_user): """Test add_user with IntegrityError (covers lines 258-260).""" from sqlalchemy.exc import IntegrityError - from unittest.mock import MagicMock - test_engine, AsyncTestingSessionLocal = async_test_db + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # First create org @@ -1038,6 +1044,7 @@ class TestOrganizationExceptionHandlers: org_id = org.id async with AsyncTestingSessionLocal() as session: + async def mock_commit(): raise IntegrityError("statement", {}, Exception("constraint failed")) @@ -1047,89 +1054,117 @@ class TestOrganizationExceptionHandlers: result.scalar_one_or_none = MagicMock(return_value=None) return result - with patch.object(session, 'execute', side_effect=mock_execute): - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock): - with pytest.raises(ValueError, match="Failed to add user to organization"): + with patch.object(session, "execute", side_effect=mock_execute): + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object(session, "rollback", new_callable=AsyncMock): + with pytest.raises( + ValueError, match="Failed to add user to organization" + ): await organization_crud.add_user( session, organization_id=org_id, - user_id=async_test_user.id + user_id=async_test_user.id, ) @pytest.mark.asyncio async def test_remove_user_database_error(self, async_test_db, async_test_user): """Test remove_user handles database errors (covers lines 291-294).""" from uuid import uuid4 - test_engine, AsyncTestingSessionLocal = async_test_db + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Delete failed")): + with patch.object( + session, "execute", side_effect=Exception("Delete failed") + ): with pytest.raises(Exception, match="Delete failed"): await organization_crud.remove_user( - session, - organization_id=uuid4(), - user_id=async_test_user.id + session, organization_id=uuid4(), user_id=async_test_user.id ) @pytest.mark.asyncio - async def test_update_user_role_database_error(self, async_test_db, async_test_user): + async def test_update_user_role_database_error( + self, async_test_db, async_test_user + ): """Test update_user_role handles database errors (covers lines 326-329).""" from uuid import uuid4 - test_engine, AsyncTestingSessionLocal = async_test_db + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Update failed")): + with patch.object( + session, "execute", side_effect=Exception("Update failed") + ): with pytest.raises(Exception, match="Update failed"): await organization_crud.update_user_role( session, organization_id=uuid4(), user_id=async_test_user.id, - role=OrganizationRole.ADMIN + role=OrganizationRole.ADMIN, ) @pytest.mark.asyncio async def test_get_organization_members_database_error(self, async_test_db): """Test get_organization_members handles database errors (covers lines 385-387).""" from uuid import uuid4 - test_engine, AsyncTestingSessionLocal = async_test_db + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Members query failed")): + with patch.object( + session, "execute", side_effect=Exception("Members query failed") + ): with pytest.raises(Exception, match="Members query failed"): - await organization_crud.get_organization_members(session, organization_id=uuid4()) + await organization_crud.get_organization_members( + session, organization_id=uuid4() + ) @pytest.mark.asyncio - async def test_get_user_organizations_database_error(self, async_test_db, async_test_user): + async def test_get_user_organizations_database_error( + self, async_test_db, async_test_user + ): """Test get_user_organizations handles database errors (covers lines 409-411).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("User orgs query failed")): + with patch.object( + session, "execute", side_effect=Exception("User orgs query failed") + ): with pytest.raises(Exception, match="User orgs query failed"): - await organization_crud.get_user_organizations(session, user_id=async_test_user.id) + await organization_crud.get_user_organizations( + session, user_id=async_test_user.id + ) @pytest.mark.asyncio - async def test_get_user_organizations_with_details_database_error(self, async_test_db, async_test_user): + async def test_get_user_organizations_with_details_database_error( + self, async_test_db, async_test_user + ): """Test get_user_organizations_with_details handles database errors (covers lines 466-468).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Details query failed")): + with patch.object( + session, "execute", side_effect=Exception("Details query failed") + ): with pytest.raises(Exception, match="Details query failed"): - await organization_crud.get_user_organizations_with_details(session, user_id=async_test_user.id) + await organization_crud.get_user_organizations_with_details( + session, user_id=async_test_user.id + ) @pytest.mark.asyncio - async def test_get_user_role_in_org_database_error(self, async_test_db, async_test_user): + async def test_get_user_role_in_org_database_error( + self, async_test_db, async_test_user + ): """Test get_user_role_in_org handles database errors (covers lines 491-493).""" from uuid import uuid4 - test_engine, AsyncTestingSessionLocal = async_test_db + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Role query failed")): + with patch.object( + session, "execute", side_effect=Exception("Role query failed") + ): with pytest.raises(Exception, match="Role query failed"): await organization_crud.get_user_role_in_org( - session, - user_id=async_test_user.id, - organization_id=uuid4() + session, user_id=async_test_user.id, organization_id=uuid4() ) diff --git a/backend/tests/crud/test_session.py b/backend/tests/crud/test_session.py index 416c33f..8b540fa 100644 --- a/backend/tests/crud/test_session.py +++ b/backend/tests/crud/test_session.py @@ -2,10 +2,12 @@ """ Comprehensive tests for async session CRUD operations. """ -import pytest -from datetime import datetime, timedelta, timezone + +from datetime import UTC, datetime, timedelta from uuid import uuid4 +import pytest + from app.crud.session import session as session_crud from app.models.user_session import UserSession from app.schemas.sessions import SessionCreate @@ -17,7 +19,7 @@ class TestGetByJti: @pytest.mark.asyncio async def test_get_by_jti_success(self, async_test_db, async_test_user): """Test getting session by JTI.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_session = UserSession( @@ -27,8 +29,8 @@ class TestGetByJti: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(user_session) await session.commit() @@ -41,7 +43,7 @@ class TestGetByJti: @pytest.mark.asyncio async def test_get_by_jti_not_found(self, async_test_db): """Test getting non-existent JTI returns None.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: result = await session_crud.get_by_jti(session, jti="nonexistent") @@ -54,7 +56,7 @@ class TestGetActiveByJti: @pytest.mark.asyncio async def test_get_active_by_jti_success(self, async_test_db, async_test_user): """Test getting active session by JTI.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_session = UserSession( @@ -64,8 +66,8 @@ class TestGetActiveByJti: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(user_session) await session.commit() @@ -78,7 +80,7 @@ class TestGetActiveByJti: @pytest.mark.asyncio async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user): """Test getting inactive session by JTI returns None.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_session = UserSession( @@ -88,8 +90,8 @@ class TestGetActiveByJti: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(user_session) await session.commit() @@ -105,7 +107,7 @@ class TestGetUserSessions: @pytest.mark.asyncio async def test_get_user_sessions_active_only(self, async_test_db, async_test_user): """Test getting only active user sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: active = UserSession( @@ -115,8 +117,8 @@ class TestGetUserSessions: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) inactive = UserSession( user_id=async_test_user.id, @@ -125,17 +127,15 @@ class TestGetUserSessions: ip_address="192.168.1.2", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add_all([active, inactive]) await session.commit() async with AsyncTestingSessionLocal() as session: results = await session_crud.get_user_sessions( - session, - user_id=str(async_test_user.id), - active_only=True + session, user_id=str(async_test_user.id), active_only=True ) assert len(results) == 1 assert results[0].is_active is True @@ -143,7 +143,7 @@ class TestGetUserSessions: @pytest.mark.asyncio async def test_get_user_sessions_all(self, async_test_db, async_test_user): """Test getting all user sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: for i in range(3): @@ -154,17 +154,15 @@ class TestGetUserSessions: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=i % 2 == 0, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(sess) await session.commit() async with AsyncTestingSessionLocal() as session: results = await session_crud.get_user_sessions( - session, - user_id=str(async_test_user.id), - active_only=False + session, user_id=str(async_test_user.id), active_only=False ) assert len(results) == 3 @@ -175,7 +173,7 @@ class TestCreateSession: @pytest.mark.asyncio async def test_create_session_success(self, async_test_db, async_test_user): """Test successfully creating a session_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: session_data = SessionCreate( @@ -185,10 +183,10 @@ class TestCreateSession: device_id="device_123", ip_address="192.168.1.100", user_agent="Mozilla/5.0", - last_used_at=datetime.now(timezone.utc), - expires_at=datetime.now(timezone.utc) + timedelta(days=7), + last_used_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(days=7), location_city="San Francisco", - location_country="USA" + location_country="USA", ) result = await session_crud.create_session(session, obj_in=session_data) @@ -204,7 +202,7 @@ class TestDeactivate: @pytest.mark.asyncio async def test_deactivate_success(self, async_test_db, async_test_user): """Test successfully deactivating a session_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_session = UserSession( @@ -214,8 +212,8 @@ class TestDeactivate: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(user_session) await session.commit() @@ -229,7 +227,7 @@ class TestDeactivate: @pytest.mark.asyncio async def test_deactivate_not_found(self, async_test_db): """Test deactivating non-existent session returns None.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: result = await session_crud.deactivate(session, session_id=str(uuid4())) @@ -240,9 +238,11 @@ class TestDeactivateAllUserSessions: """Tests for deactivate_all_user_sessions method.""" @pytest.mark.asyncio - async def test_deactivate_all_user_sessions_success(self, async_test_db, async_test_user): + async def test_deactivate_all_user_sessions_success( + self, async_test_db, async_test_user + ): """Test deactivating all user sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Create minimal sessions for test (2 instead of 5) @@ -254,16 +254,15 @@ class TestDeactivateAllUserSessions: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(sess) await session.commit() async with AsyncTestingSessionLocal() as session: count = await session_crud.deactivate_all_user_sessions( - session, - user_id=str(async_test_user.id) + session, user_id=str(async_test_user.id) ) assert count == 2 @@ -274,7 +273,7 @@ class TestUpdateLastUsed: @pytest.mark.asyncio async def test_update_last_used_success(self, async_test_db, async_test_user): """Test updating last_used_at timestamp.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_session = UserSession( @@ -284,8 +283,8 @@ class TestUpdateLastUsed: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) - timedelta(hours=1) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC) - timedelta(hours=1), ) session.add(user_session) await session.commit() @@ -303,7 +302,7 @@ class TestGetUserSessionCount: @pytest.mark.asyncio async def test_get_user_session_count_success(self, async_test_db, async_test_user): """Test getting user session count.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: for i in range(3): @@ -314,28 +313,26 @@ class TestGetUserSessionCount: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(sess) await session.commit() async with AsyncTestingSessionLocal() as session: count = await session_crud.get_user_session_count( - session, - user_id=str(async_test_user.id) + session, user_id=str(async_test_user.id) ) assert count == 3 @pytest.mark.asyncio async def test_get_user_session_count_empty(self, async_test_db): """Test getting session count for user with no sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: count = await session_crud.get_user_session_count( - session, - user_id=str(uuid4()) + session, user_id=str(uuid4()) ) assert count == 0 @@ -346,7 +343,7 @@ class TestUpdateRefreshToken: @pytest.mark.asyncio async def test_update_refresh_token_success(self, async_test_db, async_test_user): """Test updating refresh token JTI and expiration.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_session = UserSession( @@ -356,26 +353,34 @@ class TestUpdateRefreshToken: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) - timedelta(hours=1) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC) - timedelta(hours=1), ) session.add(user_session) await session.commit() await session.refresh(user_session) new_jti = "new_jti_123" - new_expires = datetime.now(timezone.utc) + timedelta(days=14) + new_expires = datetime.now(UTC) + timedelta(days=14) result = await session_crud.update_refresh_token( session, session=user_session, new_jti=new_jti, - new_expires_at=new_expires + new_expires_at=new_expires, ) assert result.refresh_token_jti == new_jti # Compare timestamps ignoring timezone info - assert abs((result.expires_at.replace(tzinfo=None) - new_expires.replace(tzinfo=None)).total_seconds()) < 1 + assert ( + abs( + ( + result.expires_at.replace(tzinfo=None) + - new_expires.replace(tzinfo=None) + ).total_seconds() + ) + < 1 + ) class TestCleanupExpired: @@ -384,7 +389,7 @@ class TestCleanupExpired: @pytest.mark.asyncio async def test_cleanup_expired_success(self, async_test_db, async_test_user): """Test cleaning up old expired inactive sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create old expired inactive session async with AsyncTestingSessionLocal() as session: @@ -395,9 +400,9 @@ class TestCleanupExpired: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) - timedelta(days=5), - last_used_at=datetime.now(timezone.utc) - timedelta(days=35), - created_at=datetime.now(timezone.utc) - timedelta(days=35) + expires_at=datetime.now(UTC) - timedelta(days=5), + last_used_at=datetime.now(UTC) - timedelta(days=35), + created_at=datetime.now(UTC) - timedelta(days=35), ) session.add(old_session) await session.commit() @@ -410,7 +415,7 @@ class TestCleanupExpired: @pytest.mark.asyncio async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user): """Test that cleanup keeps recent expired sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create recent expired inactive session (less than keep_days old) async with AsyncTestingSessionLocal() as session: @@ -421,9 +426,9 @@ class TestCleanupExpired: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) - timedelta(hours=1), - last_used_at=datetime.now(timezone.utc) - timedelta(hours=2), - created_at=datetime.now(timezone.utc) - timedelta(days=1) + expires_at=datetime.now(UTC) - timedelta(hours=1), + last_used_at=datetime.now(UTC) - timedelta(hours=2), + created_at=datetime.now(UTC) - timedelta(days=1), ) session.add(recent_session) await session.commit() @@ -436,7 +441,7 @@ class TestCleanupExpired: @pytest.mark.asyncio async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user): """Test that cleanup does not delete active sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create old expired but ACTIVE session async with AsyncTestingSessionLocal() as session: @@ -447,9 +452,9 @@ class TestCleanupExpired: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, # Active - expires_at=datetime.now(timezone.utc) - timedelta(days=5), - last_used_at=datetime.now(timezone.utc) - timedelta(days=35), - created_at=datetime.now(timezone.utc) - timedelta(days=35) + expires_at=datetime.now(UTC) - timedelta(days=5), + last_used_at=datetime.now(UTC) - timedelta(days=35), + created_at=datetime.now(UTC) - timedelta(days=35), ) session.add(active_session) await session.commit() @@ -464,9 +469,11 @@ class TestCleanupExpiredForUser: """Tests for cleanup_expired_for_user method.""" @pytest.mark.asyncio - async def test_cleanup_expired_for_user_success(self, async_test_db, async_test_user): + async def test_cleanup_expired_for_user_success( + self, async_test_db, async_test_user + ): """Test cleaning up expired sessions for specific user.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create expired inactive session for user async with AsyncTestingSessionLocal() as session: @@ -477,8 +484,8 @@ class TestCleanupExpiredForUser: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) - timedelta(days=1), - last_used_at=datetime.now(timezone.utc) - timedelta(days=2) + expires_at=datetime.now(UTC) - timedelta(days=1), + last_used_at=datetime.now(UTC) - timedelta(days=2), ) session.add(expired_session) await session.commit() @@ -486,27 +493,27 @@ class TestCleanupExpiredForUser: # Cleanup for user async with AsyncTestingSessionLocal() as session: count = await session_crud.cleanup_expired_for_user( - session, - user_id=str(async_test_user.id) + session, user_id=str(async_test_user.id) ) assert count == 1 @pytest.mark.asyncio async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db): """Test cleanup with invalid user UUID.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: with pytest.raises(ValueError, match="Invalid user ID format"): await session_crud.cleanup_expired_for_user( - session, - user_id="not-a-valid-uuid" + session, user_id="not-a-valid-uuid" ) @pytest.mark.asyncio - async def test_cleanup_expired_for_user_keeps_active(self, async_test_db, async_test_user): + async def test_cleanup_expired_for_user_keeps_active( + self, async_test_db, async_test_user + ): """Test that cleanup for user keeps active sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create expired but active session async with AsyncTestingSessionLocal() as session: @@ -517,8 +524,8 @@ class TestCleanupExpiredForUser: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, # Active - expires_at=datetime.now(timezone.utc) - timedelta(days=1), - last_used_at=datetime.now(timezone.utc) - timedelta(days=2) + expires_at=datetime.now(UTC) - timedelta(days=1), + last_used_at=datetime.now(UTC) - timedelta(days=2), ) session.add(active_session) await session.commit() @@ -526,8 +533,7 @@ class TestCleanupExpiredForUser: # Cleanup async with AsyncTestingSessionLocal() as session: count = await session_crud.cleanup_expired_for_user( - session, - user_id=str(async_test_user.id) + session, user_id=str(async_test_user.id) ) assert count == 0 # Should not delete active sessions @@ -536,9 +542,11 @@ class TestGetUserSessionsWithUser: """Tests for get_user_sessions with eager loading.""" @pytest.mark.asyncio - async def test_get_user_sessions_with_user_relationship(self, async_test_db, async_test_user): + async def test_get_user_sessions_with_user_relationship( + self, async_test_db, async_test_user + ): """Test getting sessions with user relationship loaded.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_session = UserSession( @@ -548,8 +556,8 @@ class TestGetUserSessionsWithUser: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(user_session) await session.commit() @@ -557,8 +565,6 @@ class TestGetUserSessionsWithUser: # Get with user relationship async with AsyncTestingSessionLocal() as session: results = await session_crud.get_user_sessions( - session, - user_id=str(async_test_user.id), - with_user=True + session, user_id=str(async_test_user.id), with_user=True ) assert len(results) >= 1 diff --git a/backend/tests/crud/test_session_db_failures.py b/backend/tests/crud/test_session_db_failures.py index e7dd5d2..dabf0a1 100644 --- a/backend/tests/crud/test_session_db_failures.py +++ b/backend/tests/crud/test_session_db_failures.py @@ -2,12 +2,14 @@ """ Comprehensive tests for session CRUD database failure scenarios. """ -import pytest + +from datetime import UTC, datetime, timedelta from unittest.mock import AsyncMock, patch -from sqlalchemy.exc import OperationalError, IntegrityError -from datetime import datetime, timedelta, timezone from uuid import uuid4 +import pytest +from sqlalchemy.exc import OperationalError + from app.crud.session import session as session_crud from app.models.user_session import UserSession from app.schemas.sessions import SessionCreate @@ -19,13 +21,14 @@ class TestSessionCRUDGetByJtiFailures: @pytest.mark.asyncio async def test_get_by_jti_database_error(self, async_test_db): """Test get_by_jti handles database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_execute(*args, **kwargs): raise OperationalError("DB connection lost", {}, Exception()) - with patch.object(session, 'execute', side_effect=mock_execute): + with patch.object(session, "execute", side_effect=mock_execute): with pytest.raises(OperationalError): await session_crud.get_by_jti(session, jti="test_jti") @@ -36,13 +39,14 @@ class TestSessionCRUDGetActiveByJtiFailures: @pytest.mark.asyncio async def test_get_active_by_jti_database_error(self, async_test_db): """Test get_active_by_jti handles database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_execute(*args, **kwargs): raise OperationalError("Query timeout", {}, Exception()) - with patch.object(session, 'execute', side_effect=mock_execute): + with patch.object(session, "execute", side_effect=mock_execute): with pytest.raises(OperationalError): await session_crud.get_active_by_jti(session, jti="test_jti") @@ -51,19 +55,21 @@ class TestSessionCRUDGetUserSessionsFailures: """Test get_user_sessions exception handling.""" @pytest.mark.asyncio - async def test_get_user_sessions_database_error(self, async_test_db, async_test_user): + async def test_get_user_sessions_database_error( + self, async_test_db, async_test_user + ): """Test get_user_sessions handles database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_execute(*args, **kwargs): raise OperationalError("Database error", {}, Exception()) - with patch.object(session, 'execute', side_effect=mock_execute): + with patch.object(session, "execute", side_effect=mock_execute): with pytest.raises(OperationalError): await session_crud.get_user_sessions( - session, - user_id=str(async_test_user.id) + session, user_id=str(async_test_user.id) ) @@ -71,24 +77,29 @@ class TestSessionCRUDCreateSessionFailures: """Test create_session exception handling.""" @pytest.mark.asyncio - async def test_create_session_commit_failure_triggers_rollback(self, async_test_db, async_test_user): + async def test_create_session_commit_failure_triggers_rollback( + self, async_test_db, async_test_user + ): """Test create_session handles commit failures with rollback.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_commit(): raise OperationalError("Commit failed", {}, Exception()) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: session_data = SessionCreate( user_id=async_test_user.id, refresh_token_jti=str(uuid4()), device_name="Test Device", ip_address="127.0.0.1", user_agent="Test Agent", - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) with pytest.raises(ValueError, match="Failed to create session"): @@ -97,24 +108,29 @@ class TestSessionCRUDCreateSessionFailures: mock_rollback.assert_called_once() @pytest.mark.asyncio - async def test_create_session_unexpected_error_triggers_rollback(self, async_test_db, async_test_user): + async def test_create_session_unexpected_error_triggers_rollback( + self, async_test_db, async_test_user + ): """Test create_session handles unexpected errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_commit(): raise RuntimeError("Unexpected error") - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: session_data = SessionCreate( user_id=async_test_user.id, refresh_token_jti=str(uuid4()), device_name="Test Device", ip_address="127.0.0.1", user_agent="Test Agent", - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) with pytest.raises(ValueError, match="Failed to create session"): @@ -127,9 +143,11 @@ class TestSessionCRUDDeactivateFailures: """Test deactivate exception handling.""" @pytest.mark.asyncio - async def test_deactivate_commit_failure_triggers_rollback(self, async_test_db, async_test_user): + async def test_deactivate_commit_failure_triggers_rollback( + self, async_test_db, async_test_user + ): """Test deactivate handles commit failures.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create a session first async with SessionLocal() as session: @@ -140,8 +158,8 @@ class TestSessionCRUDDeactivateFailures: ip_address="127.0.0.1", user_agent="Test Agent", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(user_session) await session.commit() @@ -150,13 +168,18 @@ class TestSessionCRUDDeactivateFailures: # Test deactivate failure async with SessionLocal() as session: + async def mock_commit(): raise OperationalError("Deactivate failed", {}, Exception()) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(OperationalError): - await session_crud.deactivate(session, session_id=str(session_id)) + await session_crud.deactivate( + session, session_id=str(session_id) + ) mock_rollback.assert_called_once() @@ -165,20 +188,24 @@ class TestSessionCRUDDeactivateAllFailures: """Test deactivate_all_user_sessions exception handling.""" @pytest.mark.asyncio - async def test_deactivate_all_commit_failure_triggers_rollback(self, async_test_db, async_test_user): + async def test_deactivate_all_commit_failure_triggers_rollback( + self, async_test_db, async_test_user + ): """Test deactivate_all handles commit failures.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_commit(): raise OperationalError("Bulk deactivate failed", {}, Exception()) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(OperationalError): await session_crud.deactivate_all_user_sessions( - session, - user_id=str(async_test_user.id) + session, user_id=str(async_test_user.id) ) mock_rollback.assert_called_once() @@ -188,9 +215,11 @@ class TestSessionCRUDUpdateLastUsedFailures: """Test update_last_used exception handling.""" @pytest.mark.asyncio - async def test_update_last_used_commit_failure_triggers_rollback(self, async_test_db, async_test_user): + async def test_update_last_used_commit_failure_triggers_rollback( + self, async_test_db, async_test_user + ): """Test update_last_used handles commit failures.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create a session async with SessionLocal() as session: @@ -201,8 +230,8 @@ class TestSessionCRUDUpdateLastUsedFailures: ip_address="127.0.0.1", user_agent="Test Agent", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) - timedelta(hours=1) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC) - timedelta(hours=1), ) session.add(user_session) await session.commit() @@ -211,15 +240,19 @@ class TestSessionCRUDUpdateLastUsedFailures: # Test update failure async with SessionLocal() as session: from sqlalchemy import select + from app.models.user_session import UserSession as US + result = await session.execute(select(US).where(US.id == user_session.id)) sess = result.scalar_one() async def mock_commit(): raise OperationalError("Update failed", {}, Exception()) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(OperationalError): await session_crud.update_last_used(session, session=sess) @@ -230,9 +263,11 @@ class TestSessionCRUDUpdateRefreshTokenFailures: """Test update_refresh_token exception handling.""" @pytest.mark.asyncio - async def test_update_refresh_token_commit_failure_triggers_rollback(self, async_test_db, async_test_user): + async def test_update_refresh_token_commit_failure_triggers_rollback( + self, async_test_db, async_test_user + ): """Test update_refresh_token handles commit failures.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Create a session async with SessionLocal() as session: @@ -243,8 +278,8 @@ class TestSessionCRUDUpdateRefreshTokenFailures: ip_address="127.0.0.1", user_agent="Test Agent", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + last_used_at=datetime.now(UTC), ) session.add(user_session) await session.commit() @@ -253,21 +288,25 @@ class TestSessionCRUDUpdateRefreshTokenFailures: # Test update failure async with SessionLocal() as session: from sqlalchemy import select + from app.models.user_session import UserSession as US + result = await session.execute(select(US).where(US.id == user_session.id)) sess = result.scalar_one() async def mock_commit(): raise OperationalError("Token update failed", {}, Exception()) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(OperationalError): await session_crud.update_refresh_token( session, session=sess, new_jti=str(uuid4()), - new_expires_at=datetime.now(timezone.utc) + timedelta(days=14) + new_expires_at=datetime.now(UTC) + timedelta(days=14), ) mock_rollback.assert_called_once() @@ -277,16 +316,21 @@ class TestSessionCRUDCleanupExpiredFailures: """Test cleanup_expired exception handling.""" @pytest.mark.asyncio - async def test_cleanup_expired_commit_failure_triggers_rollback(self, async_test_db): + async def test_cleanup_expired_commit_failure_triggers_rollback( + self, async_test_db + ): """Test cleanup_expired handles commit failures.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_commit(): raise OperationalError("Cleanup failed", {}, Exception()) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(OperationalError): await session_crud.cleanup_expired(session, keep_days=30) @@ -297,20 +341,24 @@ class TestSessionCRUDCleanupExpiredForUserFailures: """Test cleanup_expired_for_user exception handling.""" @pytest.mark.asyncio - async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(self, async_test_db, async_test_user): + async def test_cleanup_expired_for_user_commit_failure_triggers_rollback( + self, async_test_db, async_test_user + ): """Test cleanup_expired_for_user handles commit failures.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_commit(): raise OperationalError("User cleanup failed", {}, Exception()) - with patch.object(session, 'commit', side_effect=mock_commit): - with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: + with patch.object(session, "commit", side_effect=mock_commit): + with patch.object( + session, "rollback", new_callable=AsyncMock + ) as mock_rollback: with pytest.raises(OperationalError): await session_crud.cleanup_expired_for_user( - session, - user_id=str(async_test_user.id) + session, user_id=str(async_test_user.id) ) mock_rollback.assert_called_once() @@ -320,17 +368,19 @@ class TestSessionCRUDGetUserSessionCountFailures: """Test get_user_session_count exception handling.""" @pytest.mark.asyncio - async def test_get_user_session_count_database_error(self, async_test_db, async_test_user): + async def test_get_user_session_count_database_error( + self, async_test_db, async_test_user + ): """Test get_user_session_count handles database errors.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db async with SessionLocal() as session: + async def mock_execute(*args, **kwargs): raise OperationalError("Count query failed", {}, Exception()) - with patch.object(session, 'execute', side_effect=mock_execute): + with patch.object(session, "execute", side_effect=mock_execute): with pytest.raises(OperationalError): await session_crud.get_user_session_count( - session, - user_id=str(async_test_user.id) + session, user_id=str(async_test_user.id) ) diff --git a/backend/tests/crud/test_user.py b/backend/tests/crud/test_user.py index b6f27de..0500a90 100644 --- a/backend/tests/crud/test_user.py +++ b/backend/tests/crud/test_user.py @@ -2,12 +2,10 @@ """ Comprehensive tests for async user CRUD operations. """ + import pytest -from datetime import datetime, timezone -from uuid import uuid4 from app.crud.user import user as user_crud -from app.models.user import User from app.schemas.users import UserCreate, UserUpdate @@ -17,7 +15,7 @@ class TestGetByEmail: @pytest.mark.asyncio async def test_get_by_email_success(self, async_test_db, async_test_user): """Test getting user by email.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: result = await user_crud.get_by_email(session, email=async_test_user.email) @@ -28,10 +26,12 @@ class TestGetByEmail: @pytest.mark.asyncio async def test_get_by_email_not_found(self, async_test_db): """Test getting non-existent email returns None.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - result = await user_crud.get_by_email(session, email="nonexistent@example.com") + result = await user_crud.get_by_email( + session, email="nonexistent@example.com" + ) assert result is None @@ -41,7 +41,7 @@ class TestCreate: @pytest.mark.asyncio async def test_create_user_success(self, async_test_db): """Test successfully creating a user_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_data = UserCreate( @@ -49,7 +49,7 @@ class TestCreate: password="SecurePass123!", first_name="New", last_name="User", - phone_number="+1234567890" + phone_number="+1234567890", ) result = await user_crud.create(session, obj_in=user_data) @@ -65,7 +65,7 @@ class TestCreate: @pytest.mark.asyncio async def test_create_superuser_success(self, async_test_db): """Test creating a superuser.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_data = UserCreate( @@ -73,7 +73,7 @@ class TestCreate: password="SuperPass123!", first_name="Super", last_name="User", - is_superuser=True + is_superuser=True, ) result = await user_crud.create(session, obj_in=user_data) @@ -83,14 +83,14 @@ class TestCreate: @pytest.mark.asyncio async def test_create_duplicate_email_fails(self, async_test_db, async_test_user): """Test creating user with duplicate email raises ValueError.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_data = UserCreate( email=async_test_user.email, # Duplicate email password="AnotherPass123!", first_name="Duplicate", - last_name="User" + last_name="User", ) with pytest.raises(ValueError) as exc_info: @@ -105,16 +105,14 @@ class TestUpdate: @pytest.mark.asyncio async def test_update_user_basic_fields(self, async_test_db, async_test_user): """Test updating basic user fields.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Get fresh copy of user user = await user_crud.get(session, id=str(async_test_user.id)) update_data = UserUpdate( - first_name="Updated", - last_name="Name", - phone_number="+9876543210" + first_name="Updated", last_name="Name", phone_number="+9876543210" ) result = await user_crud.update(session, db_obj=user, obj_in=update_data) @@ -125,7 +123,7 @@ class TestUpdate: @pytest.mark.asyncio async def test_update_user_password(self, async_test_db): """Test updating user password.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create a fresh user for this test async with AsyncTestingSessionLocal() as session: @@ -133,7 +131,7 @@ class TestUpdate: email="passwordtest@example.com", password="OldPassword123!", first_name="Pass", - last_name="Test" + last_name="Test", ) user = await user_crud.create(session, obj_in=user_data) user_id = user.id @@ -149,12 +147,14 @@ class TestUpdate: await session.refresh(result) assert result.password_hash != old_password_hash assert result.password_hash is not None - assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed + assert ( + "NewDifferentPassword123!" not in result.password_hash + ) # Should be hashed @pytest.mark.asyncio async def test_update_user_with_dict(self, async_test_db, async_test_user): """Test updating user with dictionary.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) @@ -171,13 +171,11 @@ class TestGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_basic(self, async_test_db, async_test_user): """Test basic pagination.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: users, total = await user_crud.get_multi_with_total( - session, - skip=0, - limit=10 + session, skip=0, limit=10 ) assert total >= 1 assert len(users) >= 1 @@ -186,7 +184,7 @@ class TestGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_sorting_asc(self, async_test_db): """Test sorting in ascending order.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create multiple users async with AsyncTestingSessionLocal() as session: @@ -195,17 +193,13 @@ class TestGetMultiWithTotal: email=f"sort{i}@example.com", password="SecurePass123!", first_name=f"User{i}", - last_name="Test" + last_name="Test", ) await user_crud.create(session, obj_in=user_data) async with AsyncTestingSessionLocal() as session: - users, total = await user_crud.get_multi_with_total( - session, - skip=0, - limit=10, - sort_by="email", - sort_order="asc" + users, _total = await user_crud.get_multi_with_total( + session, skip=0, limit=10, sort_by="email", sort_order="asc" ) # Check if sorted (at least the test users) @@ -216,7 +210,7 @@ class TestGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_sorting_desc(self, async_test_db): """Test sorting in descending order.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create multiple users async with AsyncTestingSessionLocal() as session: @@ -225,17 +219,13 @@ class TestGetMultiWithTotal: email=f"desc{i}@example.com", password="SecurePass123!", first_name=f"User{i}", - last_name="Test" + last_name="Test", ) await user_crud.create(session, obj_in=user_data) async with AsyncTestingSessionLocal() as session: - users, total = await user_crud.get_multi_with_total( - session, - skip=0, - limit=10, - sort_by="email", - sort_order="desc" + users, _total = await user_crud.get_multi_with_total( + session, skip=0, limit=10, sort_by="email", sort_order="desc" ) # Check if sorted descending (at least the test users) @@ -246,7 +236,7 @@ class TestGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_filtering(self, async_test_db): """Test filtering by field.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create active and inactive users async with AsyncTestingSessionLocal() as session: @@ -254,7 +244,7 @@ class TestGetMultiWithTotal: email="active@example.com", password="SecurePass123!", first_name="Active", - last_name="User" + last_name="User", ) await user_crud.create(session, obj_in=active_user) @@ -262,23 +252,18 @@ class TestGetMultiWithTotal: email="inactive@example.com", password="SecurePass123!", first_name="Inactive", - last_name="User" + last_name="User", ) created_inactive = await user_crud.create(session, obj_in=inactive_user) # Deactivate the user await user_crud.update( - session, - db_obj=created_inactive, - obj_in={"is_active": False} + session, db_obj=created_inactive, obj_in={"is_active": False} ) async with AsyncTestingSessionLocal() as session: - users, total = await user_crud.get_multi_with_total( - session, - skip=0, - limit=100, - filters={"is_active": True} + users, _total = await user_crud.get_multi_with_total( + session, skip=0, limit=100, filters={"is_active": True} ) # All returned users should be active @@ -287,7 +272,7 @@ class TestGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_search(self, async_test_db): """Test search functionality.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create user with unique name async with AsyncTestingSessionLocal() as session: @@ -295,16 +280,13 @@ class TestGetMultiWithTotal: email="searchable@example.com", password="SecurePass123!", first_name="Searchable", - last_name="UserName" + last_name="UserName", ) await user_crud.create(session, obj_in=user_data) async with AsyncTestingSessionLocal() as session: users, total = await user_crud.get_multi_with_total( - session, - skip=0, - limit=100, - search="Searchable" + session, skip=0, limit=100, search="Searchable" ) assert total >= 1 @@ -313,7 +295,7 @@ class TestGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_pagination(self, async_test_db): """Test pagination with skip and limit.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create multiple users async with AsyncTestingSessionLocal() as session: @@ -322,23 +304,19 @@ class TestGetMultiWithTotal: email=f"page{i}@example.com", password="SecurePass123!", first_name=f"Page{i}", - last_name="User" + last_name="User", ) await user_crud.create(session, obj_in=user_data) async with AsyncTestingSessionLocal() as session: # Get first page users_page1, total = await user_crud.get_multi_with_total( - session, - skip=0, - limit=2 + session, skip=0, limit=2 ) # Get second page users_page2, total2 = await user_crud.get_multi_with_total( - session, - skip=2, - limit=2 + session, skip=2, limit=2 ) # Total should be same @@ -349,7 +327,7 @@ class TestGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_validation_negative_skip(self, async_test_db): """Test validation fails for negative skip.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: with pytest.raises(ValueError) as exc_info: @@ -360,7 +338,7 @@ class TestGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_validation_negative_limit(self, async_test_db): """Test validation fails for negative limit.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: with pytest.raises(ValueError) as exc_info: @@ -371,7 +349,7 @@ class TestGetMultiWithTotal: @pytest.mark.asyncio async def test_get_multi_with_total_validation_max_limit(self, async_test_db): """Test validation fails for limit > 1000.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: with pytest.raises(ValueError) as exc_info: @@ -386,7 +364,7 @@ class TestBulkUpdateStatus: @pytest.mark.asyncio async def test_bulk_update_status_success(self, async_test_db): """Test bulk updating user status.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create multiple users user_ids = [] @@ -396,7 +374,7 @@ class TestBulkUpdateStatus: email=f"bulk{i}@example.com", password="SecurePass123!", first_name=f"Bulk{i}", - last_name="User" + last_name="User", ) user = await user_crud.create(session, obj_in=user_data) user_ids.append(user.id) @@ -404,9 +382,7 @@ class TestBulkUpdateStatus: # Bulk deactivate async with AsyncTestingSessionLocal() as session: count = await user_crud.bulk_update_status( - session, - user_ids=user_ids, - is_active=False + session, user_ids=user_ids, is_active=False ) assert count == 3 @@ -419,20 +395,18 @@ class TestBulkUpdateStatus: @pytest.mark.asyncio async def test_bulk_update_status_empty_list(self, async_test_db): """Test bulk update with empty list returns 0.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: count = await user_crud.bulk_update_status( - session, - user_ids=[], - is_active=False + session, user_ids=[], is_active=False ) assert count == 0 @pytest.mark.asyncio async def test_bulk_update_status_reactivate(self, async_test_db): """Test bulk reactivating users.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create inactive user async with AsyncTestingSessionLocal() as session: @@ -440,7 +414,7 @@ class TestBulkUpdateStatus: email="reactivate@example.com", password="SecurePass123!", first_name="Reactivate", - last_name="User" + last_name="User", ) user = await user_crud.create(session, obj_in=user_data) # Deactivate @@ -450,9 +424,7 @@ class TestBulkUpdateStatus: # Reactivate async with AsyncTestingSessionLocal() as session: count = await user_crud.bulk_update_status( - session, - user_ids=[user_id], - is_active=True + session, user_ids=[user_id], is_active=True ) assert count == 1 @@ -468,7 +440,7 @@ class TestBulkSoftDelete: @pytest.mark.asyncio async def test_bulk_soft_delete_success(self, async_test_db): """Test bulk soft deleting users.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create multiple users user_ids = [] @@ -478,17 +450,14 @@ class TestBulkSoftDelete: email=f"delete{i}@example.com", password="SecurePass123!", first_name=f"Delete{i}", - last_name="User" + last_name="User", ) user = await user_crud.create(session, obj_in=user_data) user_ids.append(user.id) # Bulk delete async with AsyncTestingSessionLocal() as session: - count = await user_crud.bulk_soft_delete( - session, - user_ids=user_ids - ) + count = await user_crud.bulk_soft_delete(session, user_ids=user_ids) assert count == 3 # Verify all are soft deleted @@ -501,7 +470,7 @@ class TestBulkSoftDelete: @pytest.mark.asyncio async def test_bulk_soft_delete_with_exclusion(self, async_test_db): """Test bulk soft delete with excluded user_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create multiple users user_ids = [] @@ -511,7 +480,7 @@ class TestBulkSoftDelete: email=f"exclude{i}@example.com", password="SecurePass123!", first_name=f"Exclude{i}", - last_name="User" + last_name="User", ) user = await user_crud.create(session, obj_in=user_data) user_ids.append(user.id) @@ -520,9 +489,7 @@ class TestBulkSoftDelete: exclude_id = user_ids[0] async with AsyncTestingSessionLocal() as session: count = await user_crud.bulk_soft_delete( - session, - user_ids=user_ids, - exclude_user_id=exclude_id + session, user_ids=user_ids, exclude_user_id=exclude_id ) assert count == 2 # Only 2 deleted @@ -534,19 +501,16 @@ class TestBulkSoftDelete: @pytest.mark.asyncio async def test_bulk_soft_delete_empty_list(self, async_test_db): """Test bulk delete with empty list returns 0.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - count = await user_crud.bulk_soft_delete( - session, - user_ids=[] - ) + count = await user_crud.bulk_soft_delete(session, user_ids=[]) assert count == 0 @pytest.mark.asyncio async def test_bulk_soft_delete_all_excluded(self, async_test_db): """Test bulk delete where all users are excluded.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create user async with AsyncTestingSessionLocal() as session: @@ -554,7 +518,7 @@ class TestBulkSoftDelete: email="onlyuser@example.com", password="SecurePass123!", first_name="Only", - last_name="User" + last_name="User", ) user = await user_crud.create(session, obj_in=user_data) user_id = user.id @@ -562,16 +526,14 @@ class TestBulkSoftDelete: # Try to delete but exclude async with AsyncTestingSessionLocal() as session: count = await user_crud.bulk_soft_delete( - session, - user_ids=[user_id], - exclude_user_id=user_id + session, user_ids=[user_id], exclude_user_id=user_id ) assert count == 0 @pytest.mark.asyncio async def test_bulk_soft_delete_already_deleted(self, async_test_db): """Test bulk delete doesn't re-delete already deleted users.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create and delete user async with AsyncTestingSessionLocal() as session: @@ -579,7 +541,7 @@ class TestBulkSoftDelete: email="predeleted@example.com", password="SecurePass123!", first_name="PreDeleted", - last_name="User" + last_name="User", ) user = await user_crud.create(session, obj_in=user_data) user_id = user.id @@ -589,10 +551,7 @@ class TestBulkSoftDelete: # Try to delete again async with AsyncTestingSessionLocal() as session: - count = await user_crud.bulk_soft_delete( - session, - user_ids=[user_id] - ) + count = await user_crud.bulk_soft_delete(session, user_ids=[user_id]) assert count == 0 # Already deleted @@ -602,7 +561,7 @@ class TestUtilityMethods: @pytest.mark.asyncio async def test_is_active_true(self, async_test_db, async_test_user): """Test is_active returns True for active user_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) @@ -611,14 +570,14 @@ class TestUtilityMethods: @pytest.mark.asyncio async def test_is_active_false(self, async_test_db): """Test is_active returns False for inactive user_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user_data = UserCreate( email="inactive2@example.com", password="SecurePass123!", first_name="Inactive", - last_name="User" + last_name="User", ) user = await user_crud.create(session, obj_in=user_data) await user_crud.update(session, db_obj=user, obj_in={"is_active": False}) @@ -628,7 +587,7 @@ class TestUtilityMethods: @pytest.mark.asyncio async def test_is_superuser_true(self, async_test_db, async_test_superuser): """Test is_superuser returns True for superuser.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user = await user_crud.get(session, id=str(async_test_superuser.id)) @@ -637,7 +596,7 @@ class TestUtilityMethods: @pytest.mark.asyncio async def test_is_superuser_false(self, async_test_db, async_test_user): """Test is_superuser returns False for regular user_crud.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user = await user_crud.get(session, id=str(async_test_user.id)) @@ -654,42 +613,52 @@ class TestUserExceptionHandlers: async def test_get_by_email_database_error(self, async_test_db): """Test get_by_email handles database errors (covers lines 30-32).""" from unittest.mock import patch - test_engine, AsyncTestingSessionLocal = async_test_db + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: - with patch.object(session, 'execute', side_effect=Exception("Database query failed")): + with patch.object( + session, "execute", side_effect=Exception("Database query failed") + ): with pytest.raises(Exception, match="Database query failed"): await user_crud.get_by_email(session, email="test@example.com") @pytest.mark.asyncio - async def test_bulk_update_status_database_error(self, async_test_db, async_test_user): + async def test_bulk_update_status_database_error( + self, async_test_db, async_test_user + ): """Test bulk_update_status handles database errors (covers lines 205-208).""" - from unittest.mock import patch, AsyncMock - test_engine, AsyncTestingSessionLocal = async_test_db + from unittest.mock import AsyncMock, patch + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock execute to fail - with patch.object(session, 'execute', side_effect=Exception("Bulk update failed")): - with patch.object(session, 'rollback', new_callable=AsyncMock): + with patch.object( + session, "execute", side_effect=Exception("Bulk update failed") + ): + with patch.object(session, "rollback", new_callable=AsyncMock): with pytest.raises(Exception, match="Bulk update failed"): await user_crud.bulk_update_status( - session, - user_ids=[async_test_user.id], - is_active=False + session, user_ids=[async_test_user.id], is_active=False ) @pytest.mark.asyncio - async def test_bulk_soft_delete_database_error(self, async_test_db, async_test_user): + async def test_bulk_soft_delete_database_error( + self, async_test_db, async_test_user + ): """Test bulk_soft_delete handles database errors (covers lines 257-260).""" - from unittest.mock import patch, AsyncMock - test_engine, AsyncTestingSessionLocal = async_test_db + from unittest.mock import AsyncMock, patch + + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # Mock execute to fail - with patch.object(session, 'execute', side_effect=Exception("Bulk delete failed")): - with patch.object(session, 'rollback', new_callable=AsyncMock): + with patch.object( + session, "execute", side_effect=Exception("Bulk delete failed") + ): + with patch.object(session, "rollback", new_callable=AsyncMock): with pytest.raises(Exception, match="Bulk delete failed"): await user_crud.bulk_soft_delete( - session, - user_ids=[async_test_user.id] + session, user_ids=[async_test_user.id] ) diff --git a/backend/tests/models/test_user.py b/backend/tests/models/test_user.py index faaab0b..c764bfc 100755 --- a/backend/tests/models/test_user.py +++ b/backend/tests/models/test_user.py @@ -1,8 +1,10 @@ # tests/models/test_user.py import uuid -import pytest from datetime import datetime + +import pytest from sqlalchemy.exc import IntegrityError + from app.models.user import User @@ -166,7 +168,6 @@ def test_user_required_fields(db_session): db_session.rollback() - def test_user_defaults(db_session): """Test that default values are correctly set.""" # Arrange - Create a minimal user with only required fields @@ -210,22 +211,13 @@ def test_user_with_complex_json_preferences(db_session): """Test storing and retrieving complex JSON preferences.""" # Arrange - Create a user with nested JSON preferences complex_preferences = { - "theme": { - "mode": "dark", - "colors": { - "primary": "#333", - "secondary": "#666" - } - }, + "theme": {"mode": "dark", "colors": {"primary": "#333", "secondary": "#666"}}, "notifications": { "email": True, "sms": False, - "push": { - "enabled": True, - "quiet_hours": [22, 7] - } + "push": {"enabled": True, "quiet_hours": [22, 7]}, }, - "tags": ["important", "family", "events"] + "tags": ["important", "family", "events"], } user = User( @@ -234,16 +226,18 @@ def test_user_with_complex_json_preferences(db_session): password_hash="hashedpassword", first_name="Complex", last_name="JSON", - preferences=complex_preferences + preferences=complex_preferences, ) db_session.add(user) db_session.commit() # Act - Retrieve the user - retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first() + retrieved_user = ( + db_session.query(User).filter_by(email="complex@example.com").first() + ) # Assert - The complex JSON should be preserved assert retrieved_user.preferences == complex_preferences assert retrieved_user.preferences["theme"]["colors"]["primary"] == "#333" assert retrieved_user.preferences["notifications"]["push"]["quiet_hours"] == [22, 7] - assert "important" in retrieved_user.preferences["tags"] \ No newline at end of file + assert "important" in retrieved_user.preferences["tags"] diff --git a/backend/tests/schemas/test_organizations.py b/backend/tests/schemas/test_organizations.py index b4e00c6..2e6ad78 100644 --- a/backend/tests/schemas/test_organizations.py +++ b/backend/tests/schemas/test_organizations.py @@ -5,6 +5,7 @@ Covers Pydantic validators for: - Slug validation (lines 26, 28, 30, 32, 62-70) - Name validation (lines 40, 77) """ + import pytest from pydantic import ValidationError @@ -20,19 +21,13 @@ class TestOrganizationBaseValidators: def test_valid_organization_base(self): """Test that valid data passes validation.""" - org = OrganizationBase( - name="Test Organization", - slug="test-org" - ) + org = OrganizationBase(name="Test Organization", slug="test-org") assert org.name == "Test Organization" assert org.slug == "test-org" def test_slug_none_returns_none(self): """Test that None slug is allowed (covers line 26).""" - org = OrganizationBase( - name="Test Organization", - slug=None - ) + org = OrganizationBase(name="Test Organization", slug=None) assert org.slug is None def test_slug_invalid_characters_rejected(self): @@ -40,57 +35,46 @@ class TestOrganizationBaseValidators: with pytest.raises(ValidationError) as exc_info: OrganizationBase( name="Test Organization", - slug="Test_Org!" # Uppercase and special chars + slug="Test_Org!", # Uppercase and special chars ) errors = exc_info.value.errors() - assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors) + assert any( + "lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors + ) def test_slug_starts_with_hyphen_rejected(self): """Test slug starting with hyphen is rejected (covers line 30).""" with pytest.raises(ValidationError) as exc_info: - OrganizationBase( - name="Test Organization", - slug="-test-org" - ) + OrganizationBase(name="Test Organization", slug="-test-org") errors = exc_info.value.errors() - assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors) + assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors) def test_slug_ends_with_hyphen_rejected(self): """Test slug ending with hyphen is rejected (covers line 30).""" with pytest.raises(ValidationError) as exc_info: - OrganizationBase( - name="Test Organization", - slug="test-org-" - ) + OrganizationBase(name="Test Organization", slug="test-org-") errors = exc_info.value.errors() - assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors) + assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors) def test_slug_consecutive_hyphens_rejected(self): """Test slug with consecutive hyphens is rejected (covers line 32).""" with pytest.raises(ValidationError) as exc_info: - OrganizationBase( - name="Test Organization", - slug="test--org" - ) + OrganizationBase(name="Test Organization", slug="test--org") errors = exc_info.value.errors() - assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors) + assert any( + "cannot contain consecutive hyphens" in str(e["msg"]) for e in errors + ) def test_name_whitespace_only_rejected(self): """Test whitespace-only name is rejected (covers line 40).""" with pytest.raises(ValidationError) as exc_info: - OrganizationBase( - name=" ", - slug="test-org" - ) + OrganizationBase(name=" ", slug="test-org") errors = exc_info.value.errors() - assert any("name cannot be empty" in str(e['msg']) for e in errors) + assert any("name cannot be empty" in str(e["msg"]) for e in errors) def test_name_trimmed(self): """Test that name is trimmed.""" - org = OrganizationBase( - name=" Test Organization ", - slug="test-org" - ) + org = OrganizationBase(name=" Test Organization ", slug="test-org") assert org.name == "Test Organization" @@ -99,22 +83,18 @@ class TestOrganizationCreateValidators: def test_valid_organization_create(self): """Test that valid data passes validation.""" - org = OrganizationCreate( - name="Test Organization", - slug="test-org" - ) + org = OrganizationCreate(name="Test Organization", slug="test-org") assert org.name == "Test Organization" assert org.slug == "test-org" def test_slug_validation_inherited(self): """Test that slug validation is inherited from base.""" with pytest.raises(ValidationError) as exc_info: - OrganizationCreate( - name="Test", - slug="Invalid_Slug!" - ) + OrganizationCreate(name="Test", slug="Invalid_Slug!") errors = exc_info.value.errors() - assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors) + assert any( + "lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors + ) class TestOrganizationUpdateValidators: @@ -122,10 +102,7 @@ class TestOrganizationUpdateValidators: def test_valid_organization_update(self): """Test that valid update data passes validation.""" - org = OrganizationUpdate( - name="Updated Name", - slug="updated-slug" - ) + org = OrganizationUpdate(name="Updated Name", slug="updated-slug") assert org.name == "Updated Name" assert org.slug == "updated-slug" @@ -139,35 +116,39 @@ class TestOrganizationUpdateValidators: with pytest.raises(ValidationError) as exc_info: OrganizationUpdate(slug="Test_Org!") errors = exc_info.value.errors() - assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors) + assert any( + "lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors + ) def test_update_slug_starts_with_hyphen_rejected(self): """Test update slug starting with hyphen is rejected (covers line 66).""" with pytest.raises(ValidationError) as exc_info: OrganizationUpdate(slug="-test-org") errors = exc_info.value.errors() - assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors) + assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors) def test_update_slug_ends_with_hyphen_rejected(self): """Test update slug ending with hyphen is rejected (covers line 66).""" with pytest.raises(ValidationError) as exc_info: OrganizationUpdate(slug="test-org-") errors = exc_info.value.errors() - assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors) + assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors) def test_update_slug_consecutive_hyphens_rejected(self): """Test update slug with consecutive hyphens is rejected (covers line 68).""" with pytest.raises(ValidationError) as exc_info: OrganizationUpdate(slug="test--org") errors = exc_info.value.errors() - assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors) + assert any( + "cannot contain consecutive hyphens" in str(e["msg"]) for e in errors + ) def test_update_name_whitespace_only_rejected(self): """Test whitespace-only name in update is rejected (covers line 77).""" with pytest.raises(ValidationError) as exc_info: OrganizationUpdate(name=" ") errors = exc_info.value.errors() - assert any("name cannot be empty" in str(e['msg']) for e in errors) + assert any("name cannot be empty" in str(e["msg"]) for e in errors) def test_update_name_none_allowed(self): """Test that None name is allowed in update.""" diff --git a/backend/tests/schemas/test_user_schemas.py b/backend/tests/schemas/test_user_schemas.py index 826f13b..df7561e 100755 --- a/backend/tests/schemas/test_user_schemas.py +++ b/backend/tests/schemas/test_user_schemas.py @@ -1,80 +1,177 @@ # tests/schemas/test_user_schemas.py -import pytest import re + +import pytest from pydantic import ValidationError from app.schemas.users import UserBase, UserCreate + class TestPhoneNumberValidation: """Tests for phone number validation in user schemas""" def test_valid_swiss_numbers(self): """Test valid Swiss phone numbers are accepted""" # International format - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567") + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="+41791234567", + ) assert user.phone_number == "+41791234567" # Local format - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567") + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="0791234567", + ) assert user.phone_number == "0791234567" # With formatting characters - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="+41 79 123 45 67", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="079 123 45 67", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="+41-79-123-45-67", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="079-123-45-67", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="+41 (79) 123 45 67", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="079 (123) 45 67", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567" def test_valid_italian_numbers(self): """Test valid Italian phone numbers are accepted""" # International format - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567") + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="+393451234567", + ) assert user.phone_number == "+393451234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456") + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="+39345123456", + ) assert user.phone_number == "+39345123456" # Local format - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567") + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="03451234567", + ) assert user.phone_number == "03451234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789") + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="0345123456789", + ) assert user.phone_number == "0345123456789" # With formatting characters - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="+39 345 123 4567", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="0345 123 4567", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="+39-345-123-4567", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="0345-123-4567", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="+39 (345) 123 4567", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567") - assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567" + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number="0345 (123) 4567", + ) + assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567" def test_none_phone_number(self): """Test that None is accepted as a valid value (optional phone number)""" - user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None) + user = UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number=None, + ) assert user.phone_number is None def test_invalid_phone_numbers(self): @@ -83,17 +180,14 @@ class TestPhoneNumberValidation: # Too short "+12", "012", - # Invalid characters "+41xyz123456", "079abc4567", "123-abc-7890", "+1(800)CALL-NOW", - # Completely invalid formats "++4412345678", # Double plus # Note: "()+41123456" becomes "+41123456" after cleaning, which is valid - # Empty string "", # Spaces only @@ -102,7 +196,12 @@ class TestPhoneNumberValidation: for number in invalid_numbers: with pytest.raises(ValidationError): - UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number) + UserBase( + email="test@example.com", + first_name="Test", + last_name="User", + phone_number=number, + ) def test_phone_validation_in_user_create(self): """Test that phone validation also works in UserCreate schema""" @@ -112,7 +211,7 @@ class TestPhoneNumberValidation: first_name="Test", last_name="User", password="Password123!", - phone_number="+41791234567" + phone_number="+41791234567", ) assert user.phone_number == "+41791234567" @@ -123,5 +222,5 @@ class TestPhoneNumberValidation: first_name="Test", last_name="User", password="Password123!", - phone_number="invalid-number" - ) \ No newline at end of file + phone_number="invalid-number", + ) diff --git a/backend/tests/schemas/test_validators.py b/backend/tests/schemas/test_validators.py index c5f10da..ddea360 100644 --- a/backend/tests/schemas/test_validators.py +++ b/backend/tests/schemas/test_validators.py @@ -7,12 +7,13 @@ Covers all edge cases in validation functions: - validate_email_format (line 148) - validate_slug (lines 170-183) """ + import pytest from app.schemas.validators import ( + validate_email_format, validate_password_strength, validate_phone_number, - validate_email_format, validate_slug, ) @@ -108,12 +109,14 @@ class TestPhoneNumberValidator: validate_phone_number("+123456789012345") # 15 digits after + def test_multiple_plus_symbols_rejected(self): - """Test phone number with multiple + symbols. + r"""Test phone number with multiple + symbols. Note: Line 115 is defensive code - the regex check at line 110 catches this first. The regex ^(?:\+[0-9]{8,14}|0[0-9]{8,14})$ only allows + at the start. """ - with pytest.raises(ValueError, match="must start with \\+ or 0 followed by 8-14 digits"): + with pytest.raises( + ValueError, match="must start with \\+ or 0 followed by 8-14 digits" + ): validate_phone_number("+1234+5678901") def test_non_digit_after_prefix_rejected(self): diff --git a/backend/tests/services/test_auth_service.py b/backend/tests/services/test_auth_service.py index 78e9347..cf6f84c 100755 --- a/backend/tests/services/test_auth_service.py +++ b/backend/tests/services/test_auth_service.py @@ -1,14 +1,18 @@ # tests/services/test_auth_service.py import uuid -import pytest -import pytest_asyncio from unittest.mock import patch + +import pytest from sqlalchemy import select -from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError +from app.core.auth import ( + TokenInvalidError, + get_password_hash, + verify_password, +) from app.models.user import User -from app.schemas.users import UserCreate, Token -from app.services.auth_service import AuthService, AuthenticationError +from app.schemas.users import Token, UserCreate +from app.services.auth_service import AuthenticationError, AuthService class TestAuthServiceAuthentication: @@ -17,12 +21,14 @@ class TestAuthServiceAuthentication: @pytest.mark.asyncio async def test_authenticate_valid_user(self, async_test_db, async_test_user): """Test authenticating a user with valid credentials""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Set a known password for the mock user password = "TestPassword123!" async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user = result.scalar_one_or_none() user.password_hash = get_password_hash(password) await session.commit() @@ -30,9 +36,7 @@ class TestAuthServiceAuthentication: # Authenticate with correct credentials async with AsyncTestingSessionLocal() as session: auth_user = await AuthService.authenticate_user( - db=session, - email=async_test_user.email, - password=password + db=session, email=async_test_user.email, password=password ) assert auth_user is not None @@ -42,26 +46,28 @@ class TestAuthServiceAuthentication: @pytest.mark.asyncio async def test_authenticate_nonexistent_user(self, async_test_db): """Test authenticating with an email that doesn't exist""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: user = await AuthService.authenticate_user( - db=session, - email="nonexistent@example.com", - password="password" + db=session, email="nonexistent@example.com", password="password" ) assert user is None @pytest.mark.asyncio - async def test_authenticate_with_wrong_password(self, async_test_db, async_test_user): + async def test_authenticate_with_wrong_password( + self, async_test_db, async_test_user + ): """Test authenticating with the wrong password""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Set a known password for the mock user password = "TestPassword123!" async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user = result.scalar_one_or_none() user.password_hash = get_password_hash(password) await session.commit() @@ -69,9 +75,7 @@ class TestAuthServiceAuthentication: # Authenticate with wrong password async with AsyncTestingSessionLocal() as session: auth_user = await AuthService.authenticate_user( - db=session, - email=async_test_user.email, - password="WrongPassword123" + db=session, email=async_test_user.email, password="WrongPassword123" ) assert auth_user is None @@ -79,12 +83,14 @@ class TestAuthServiceAuthentication: @pytest.mark.asyncio async def test_authenticate_inactive_user(self, async_test_db, async_test_user): """Test authenticating an inactive user""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Set a known password and make user inactive password = "TestPassword123!" async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user = result.scalar_one_or_none() user.password_hash = get_password_hash(password) user.is_active = False @@ -94,9 +100,7 @@ class TestAuthServiceAuthentication: async with AsyncTestingSessionLocal() as session: with pytest.raises(AuthenticationError): await AuthService.authenticate_user( - db=session, - email=async_test_user.email, - password=password + db=session, email=async_test_user.email, password=password ) @@ -106,14 +110,14 @@ class TestAuthServiceUserCreation: @pytest.mark.asyncio async def test_create_new_user(self, async_test_db): """Test creating a new user""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db user_data = UserCreate( email="newuser@example.com", password="TestPassword123!", first_name="New", last_name="User", - phone_number="+1234567890" + phone_number="+1234567890", ) async with AsyncTestingSessionLocal() as session: @@ -135,15 +139,17 @@ class TestAuthServiceUserCreation: assert user.is_superuser is False @pytest.mark.asyncio - async def test_create_user_with_existing_email(self, async_test_db, async_test_user): + async def test_create_user_with_existing_email( + self, async_test_db, async_test_user + ): """Test creating a user with an email that already exists""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db user_data = UserCreate( email=async_test_user.email, # Use existing email password="TestPassword123!", first_name="Duplicate", - last_name="User" + last_name="User", ) # Should raise AuthenticationError @@ -169,7 +175,7 @@ class TestAuthServiceTokens: @pytest.mark.asyncio async def test_refresh_tokens(self, async_test_db, async_test_user): """Test refreshing tokens with a valid refresh token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create initial tokens initial_tokens = AuthService.create_tokens(async_test_user) @@ -177,8 +183,7 @@ class TestAuthServiceTokens: # Refresh tokens async with AsyncTestingSessionLocal() as session: new_tokens = await AuthService.refresh_tokens( - db=session, - refresh_token=initial_tokens.refresh_token + db=session, refresh_token=initial_tokens.refresh_token ) # Verify new tokens are different from old ones @@ -188,7 +193,7 @@ class TestAuthServiceTokens: @pytest.mark.asyncio async def test_refresh_tokens_with_invalid_token(self, async_test_db): """Test refreshing tokens with an invalid token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create an invalid token invalid_token = "invalid.token.string" @@ -197,14 +202,15 @@ class TestAuthServiceTokens: async with AsyncTestingSessionLocal() as session: with pytest.raises(TokenInvalidError): await AuthService.refresh_tokens( - db=session, - refresh_token=invalid_token + db=session, refresh_token=invalid_token ) @pytest.mark.asyncio - async def test_refresh_tokens_with_access_token(self, async_test_db, async_test_user): + async def test_refresh_tokens_with_access_token( + self, async_test_db, async_test_user + ): """Test refreshing tokens with an access token instead of refresh token""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create tokens tokens = AuthService.create_tokens(async_test_user) @@ -213,18 +219,20 @@ class TestAuthServiceTokens: async with AsyncTestingSessionLocal() as session: with pytest.raises(TokenInvalidError): await AuthService.refresh_tokens( - db=session, - refresh_token=tokens.access_token + db=session, refresh_token=tokens.access_token ) @pytest.mark.asyncio async def test_refresh_tokens_with_nonexistent_user(self, async_test_db): """Test refreshing tokens for a user that doesn't exist in the database""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create a token for a non-existent user non_existent_id = str(uuid.uuid4()) - with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data: + with ( + patch("app.core.auth.decode_token"), + patch("app.core.auth.get_token_data") as mock_get_data, + ): # Mock the token data to return a non-existent user ID mock_get_data.return_value.user_id = uuid.UUID(non_existent_id) @@ -232,8 +240,7 @@ class TestAuthServiceTokens: async with AsyncTestingSessionLocal() as session: with pytest.raises(TokenInvalidError): await AuthService.refresh_tokens( - db=session, - refresh_token="some.refresh.token" + db=session, refresh_token="some.refresh.token" ) @@ -243,12 +250,14 @@ class TestAuthServicePasswordChange: @pytest.mark.asyncio async def test_change_password(self, async_test_db, async_test_user): """Test changing a user's password""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Set a known password for the mock user current_password = "CurrentPassword123" async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user = result.scalar_one_or_none() user.password_hash = get_password_hash(current_password) await session.commit() @@ -260,7 +269,7 @@ class TestAuthServicePasswordChange: db=session, user_id=async_test_user.id, current_password=current_password, - new_password=new_password + new_password=new_password, ) # Verify operation was successful @@ -268,7 +277,9 @@ class TestAuthServicePasswordChange: # Verify password was changed async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) updated_user = result.scalar_one_or_none() # Verify old password no longer works @@ -278,14 +289,18 @@ class TestAuthServicePasswordChange: assert verify_password(new_password, updated_user.password_hash) @pytest.mark.asyncio - async def test_change_password_wrong_current_password(self, async_test_db, async_test_user): + async def test_change_password_wrong_current_password( + self, async_test_db, async_test_user + ): """Test changing password with incorrect current password""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Set a known password for the mock user current_password = "CurrentPassword123" async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user = result.scalar_one_or_none() user.password_hash = get_password_hash(current_password) await session.commit() @@ -298,19 +313,21 @@ class TestAuthServicePasswordChange: db=session, user_id=async_test_user.id, current_password=wrong_password, - new_password="NewPassword456" + new_password="NewPassword456", ) # Verify password was not changed async with AsyncTestingSessionLocal() as session: - result = await session.execute(select(User).where(User.id == async_test_user.id)) + result = await session.execute( + select(User).where(User.id == async_test_user.id) + ) user = result.scalar_one_or_none() assert verify_password(current_password, user.password_hash) @pytest.mark.asyncio async def test_change_password_nonexistent_user(self, async_test_db): """Test changing password for a user that doesn't exist""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db non_existent_id = uuid.uuid4() @@ -320,5 +337,5 @@ class TestAuthServicePasswordChange: db=session, user_id=non_existent_id, current_password="CurrentPassword123", - new_password="NewPassword456" + new_password="NewPassword456", ) diff --git a/backend/tests/services/test_email_service.py b/backend/tests/services/test_email_service.py index 77d8c9d..f0da405 100755 --- a/backend/tests/services/test_email_service.py +++ b/backend/tests/services/test_email_service.py @@ -2,13 +2,15 @@ """ Tests for email service functionality. """ + +from unittest.mock import AsyncMock + import pytest -from unittest.mock import patch, AsyncMock, MagicMock from app.services.email_service import ( - EmailService, ConsoleEmailBackend, - SMTPEmailBackend + EmailService, + SMTPEmailBackend, ) @@ -24,7 +26,7 @@ class TestConsoleEmailBackend: to=["user@example.com"], subject="Test Subject", html_content="

Test HTML

", - text_content="Test Text" + text_content="Test Text", ) assert result is True @@ -37,7 +39,7 @@ class TestConsoleEmailBackend: result = await backend.send_email( to=["user@example.com"], subject="Test Subject", - html_content="

Test HTML

" + html_content="

Test HTML

", ) assert result is True @@ -50,7 +52,7 @@ class TestConsoleEmailBackend: result = await backend.send_email( to=["user1@example.com", "user2@example.com"], subject="Test Subject", - html_content="

Test HTML

" + html_content="

Test HTML

", ) assert result is True @@ -66,7 +68,7 @@ class TestSMTPEmailBackend: host="smtp.example.com", port=587, username="test@example.com", - password="password" + password="password", ) assert backend.host == "smtp.example.com" @@ -81,14 +83,14 @@ class TestSMTPEmailBackend: host="smtp.example.com", port=587, username="test@example.com", - password="password" + password="password", ) # Should fall back to console backend since SMTP is not implemented result = await backend.send_email( to=["user@example.com"], subject="Test Subject", - html_content="

Test HTML

" + html_content="

Test HTML

", ) assert result is True @@ -114,9 +116,7 @@ class TestEmailService: service = EmailService() result = await service.send_password_reset_email( - to_email="user@example.com", - reset_token="test_token_123", - user_name="John" + to_email="user@example.com", reset_token="test_token_123", user_name="John" ) assert result is True @@ -127,8 +127,7 @@ class TestEmailService: service = EmailService() result = await service.send_password_reset_email( - to_email="user@example.com", - reset_token="test_token_123" + to_email="user@example.com", reset_token="test_token_123" ) assert result is True @@ -142,8 +141,7 @@ class TestEmailService: token = "test_reset_token_xyz" await service.send_password_reset_email( - to_email="user@example.com", - reset_token=token + to_email="user@example.com", reset_token=token ) # Verify send_email was called @@ -151,7 +149,7 @@ class TestEmailService: call_args = backend_mock.send_email.call_args # Check that token is in the HTML content - html_content = call_args.kwargs['html_content'] + html_content = call_args.kwargs["html_content"] assert token in html_content @pytest.mark.asyncio @@ -162,8 +160,7 @@ class TestEmailService: service = EmailService(backend=backend_mock) result = await service.send_password_reset_email( - to_email="user@example.com", - reset_token="test_token" + to_email="user@example.com", reset_token="test_token" ) assert result is False @@ -176,7 +173,7 @@ class TestEmailService: result = await service.send_email_verification( to_email="user@example.com", verification_token="verification_token_123", - user_name="Jane" + user_name="Jane", ) assert result is True @@ -187,8 +184,7 @@ class TestEmailService: service = EmailService() result = await service.send_email_verification( - to_email="user@example.com", - verification_token="verification_token_123" + to_email="user@example.com", verification_token="verification_token_123" ) assert result is True @@ -202,8 +198,7 @@ class TestEmailService: token = "test_verification_token_xyz" await service.send_email_verification( - to_email="user@example.com", - verification_token=token + to_email="user@example.com", verification_token=token ) # Verify send_email was called @@ -211,7 +206,7 @@ class TestEmailService: call_args = backend_mock.send_email.call_args # Check that token is in the HTML content - html_content = call_args.kwargs['html_content'] + html_content = call_args.kwargs["html_content"] assert token in html_content @pytest.mark.asyncio @@ -222,8 +217,7 @@ class TestEmailService: service = EmailService(backend=backend_mock) result = await service.send_email_verification( - to_email="user@example.com", - verification_token="test_token" + to_email="user@example.com", verification_token="test_token" ) assert result is False @@ -236,14 +230,12 @@ class TestEmailService: service = EmailService(backend=backend_mock) await service.send_password_reset_email( - to_email="user@example.com", - reset_token="token123", - user_name="Test User" + to_email="user@example.com", reset_token="token123", user_name="Test User" ) call_args = backend_mock.send_email.call_args - html_content = call_args.kwargs['html_content'] - text_content = call_args.kwargs['text_content'] + html_content = call_args.kwargs["html_content"] + text_content = call_args.kwargs["text_content"] # Check HTML content assert "Password Reset" in html_content @@ -251,7 +243,9 @@ class TestEmailService: assert "Test User" in html_content # Check text content - assert "Password Reset" in text_content or "password reset" in text_content.lower() + assert ( + "Password Reset" in text_content or "password reset" in text_content.lower() + ) assert "token123" in text_content @pytest.mark.asyncio @@ -264,12 +258,12 @@ class TestEmailService: await service.send_email_verification( to_email="user@example.com", verification_token="verify123", - user_name="Test User" + user_name="Test User", ) call_args = backend_mock.send_email.call_args - html_content = call_args.kwargs['html_content'] - text_content = call_args.kwargs['text_content'] + html_content = call_args.kwargs["html_content"] + text_content = call_args.kwargs["text_content"] # Check HTML content assert "Verify" in html_content diff --git a/backend/tests/services/test_session_cleanup.py b/backend/tests/services/test_session_cleanup.py index c2e096e..4fbc227 100644 --- a/backend/tests/services/test_session_cleanup.py +++ b/backend/tests/services/test_session_cleanup.py @@ -2,23 +2,27 @@ """ Comprehensive tests for session cleanup service. """ -import pytest + import asyncio -from datetime import datetime, timedelta, timezone -from unittest.mock import patch, MagicMock, AsyncMock from contextlib import asynccontextmanager +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, patch + +import pytest +from sqlalchemy import select from app.models.user_session import UserSession -from sqlalchemy import select class TestCleanupExpiredSessions: """Tests for cleanup_expired_sessions function.""" @pytest.mark.asyncio - async def test_cleanup_expired_sessions_success(self, async_test_db, async_test_user): + async def test_cleanup_expired_sessions_success( + self, async_test_db, async_test_user + ): """Test successful cleanup of expired sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create mix of sessions async with AsyncTestingSessionLocal() as session: @@ -30,9 +34,9 @@ class TestCleanupExpiredSessions: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - created_at=datetime.now(timezone.utc) - timedelta(days=1), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC) - timedelta(days=1), + last_used_at=datetime.now(UTC), ) # 2. Inactive, expired, old (SHOULD be deleted) @@ -43,9 +47,9 @@ class TestCleanupExpiredSessions: ip_address="192.168.1.2", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) - timedelta(days=10), - created_at=datetime.now(timezone.utc) - timedelta(days=40), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) - timedelta(days=10), + created_at=datetime.now(UTC) - timedelta(days=40), + last_used_at=datetime.now(UTC), ) # 3. Inactive, expired, recent (should NOT be deleted - within keep_days) @@ -56,17 +60,23 @@ class TestCleanupExpiredSessions: ip_address="192.168.1.3", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) - timedelta(days=1), - created_at=datetime.now(timezone.utc) - timedelta(days=5), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) - timedelta(days=1), + created_at=datetime.now(UTC) - timedelta(days=5), + last_used_at=datetime.now(UTC), ) - session.add_all([active_session, old_expired_session, recent_expired_session]) + session.add_all( + [active_session, old_expired_session, recent_expired_session] + ) await session.commit() # Mock SessionLocal to return our test session - with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): + with patch( + "app.services.session_cleanup.SessionLocal", + return_value=AsyncTestingSessionLocal(), + ): from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=30) # Should only delete old_expired_session @@ -85,7 +95,7 @@ class TestCleanupExpiredSessions: @pytest.mark.asyncio async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user): """Test cleanup when no sessions meet deletion criteria.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: active = UserSession( @@ -95,15 +105,19 @@ class TestCleanupExpiredSessions: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - created_at=datetime.now(timezone.utc), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + last_used_at=datetime.now(UTC), ) session.add(active) await session.commit() - with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): + with patch( + "app.services.session_cleanup.SessionLocal", + return_value=AsyncTestingSessionLocal(), + ): from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=30) assert deleted_count == 0 @@ -111,10 +125,14 @@ class TestCleanupExpiredSessions: @pytest.mark.asyncio async def test_cleanup_empty_database(self, async_test_db): """Test cleanup with no sessions in database.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db - with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): + with patch( + "app.services.session_cleanup.SessionLocal", + return_value=AsyncTestingSessionLocal(), + ): from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=30) assert deleted_count == 0 @@ -122,7 +140,7 @@ class TestCleanupExpiredSessions: @pytest.mark.asyncio async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user): """Test cleanup with keep_days=0 deletes all inactive expired sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: today_expired = UserSession( @@ -132,15 +150,19 @@ class TestCleanupExpiredSessions: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) - timedelta(hours=1), - created_at=datetime.now(timezone.utc) - timedelta(hours=2), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) - timedelta(hours=1), + created_at=datetime.now(UTC) - timedelta(hours=2), + last_used_at=datetime.now(UTC), ) session.add(today_expired) await session.commit() - with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): + with patch( + "app.services.session_cleanup.SessionLocal", + return_value=AsyncTestingSessionLocal(), + ): from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=0) assert deleted_count == 1 @@ -148,7 +170,7 @@ class TestCleanupExpiredSessions: @pytest.mark.asyncio async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user): """Test that cleanup uses bulk DELETE for many sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create 50 expired sessions async with AsyncTestingSessionLocal() as session: @@ -161,16 +183,20 @@ class TestCleanupExpiredSessions: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) - timedelta(days=10), - created_at=datetime.now(timezone.utc) - timedelta(days=40), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) - timedelta(days=10), + created_at=datetime.now(UTC) - timedelta(days=40), + last_used_at=datetime.now(UTC), ) sessions_to_add.append(expired) session.add_all(sessions_to_add) await session.commit() - with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): + with patch( + "app.services.session_cleanup.SessionLocal", + return_value=AsyncTestingSessionLocal(), + ): from app.services.session_cleanup import cleanup_expired_sessions + deleted_count = await cleanup_expired_sessions(keep_days=30) assert deleted_count == 50 @@ -178,14 +204,20 @@ class TestCleanupExpiredSessions: @pytest.mark.asyncio async def test_cleanup_database_error_returns_zero(self, async_test_db): """Test cleanup returns 0 on database errors (doesn't crash).""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Mock session_crud.cleanup_expired to raise error - with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): - with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup: + with patch( + "app.services.session_cleanup.SessionLocal", + return_value=AsyncTestingSessionLocal(), + ): + with patch( + "app.services.session_cleanup.session_crud.cleanup_expired" + ) as mock_cleanup: mock_cleanup.side_effect = Exception("Database connection lost") from app.services.session_cleanup import cleanup_expired_sessions + # Should not crash, should return 0 deleted_count = await cleanup_expired_sessions(keep_days=30) @@ -198,7 +230,7 @@ class TestGetSessionStatistics: @pytest.mark.asyncio async def test_get_statistics_with_sessions(self, async_test_db, async_test_user): """Test getting session statistics with various session types.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db async with AsyncTestingSessionLocal() as session: # 2 active, not expired @@ -210,9 +242,9 @@ class TestGetSessionStatistics: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - created_at=datetime.now(timezone.utc), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + last_used_at=datetime.now(UTC), ) session.add(active) @@ -225,9 +257,9 @@ class TestGetSessionStatistics: ip_address="192.168.1.2", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) - timedelta(days=1), - created_at=datetime.now(timezone.utc) - timedelta(days=2), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) - timedelta(days=1), + created_at=datetime.now(UTC) - timedelta(days=2), + last_used_at=datetime.now(UTC), ) session.add(inactive) @@ -239,16 +271,20 @@ class TestGetSessionStatistics: ip_address="192.168.1.3", user_agent="Mozilla/5.0", is_active=True, - expires_at=datetime.now(timezone.utc) - timedelta(hours=1), - created_at=datetime.now(timezone.utc) - timedelta(days=1), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) - timedelta(hours=1), + created_at=datetime.now(UTC) - timedelta(days=1), + last_used_at=datetime.now(UTC), ) session.add(expired_active) await session.commit() - with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): + with patch( + "app.services.session_cleanup.SessionLocal", + return_value=AsyncTestingSessionLocal(), + ): from app.services.session_cleanup import get_session_statistics + stats = await get_session_statistics() assert stats["total"] == 6 @@ -259,10 +295,14 @@ class TestGetSessionStatistics: @pytest.mark.asyncio async def test_get_statistics_empty_database(self, async_test_db): """Test getting statistics with no sessions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db - with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): + with patch( + "app.services.session_cleanup.SessionLocal", + return_value=AsyncTestingSessionLocal(), + ): from app.services.session_cleanup import get_session_statistics + stats = await get_session_statistics() assert stats["total"] == 0 @@ -271,9 +311,11 @@ class TestGetSessionStatistics: assert stats["expired"] == 0 @pytest.mark.asyncio - async def test_get_statistics_database_error_returns_empty_dict(self, async_test_db): + async def test_get_statistics_database_error_returns_empty_dict( + self, async_test_db + ): """Test statistics returns empty dict on database errors.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, _AsyncTestingSessionLocal = async_test_db # Create a mock that raises on execute mock_session = AsyncMock() @@ -283,8 +325,12 @@ class TestGetSessionStatistics: async def mock_session_local(): yield mock_session - with patch('app.services.session_cleanup.SessionLocal', return_value=mock_session_local()): + with patch( + "app.services.session_cleanup.SessionLocal", + return_value=mock_session_local(), + ): from app.services.session_cleanup import get_session_statistics + stats = await get_session_statistics() assert stats == {} @@ -294,9 +340,11 @@ class TestConcurrentCleanup: """Tests for concurrent cleanup scenarios.""" @pytest.mark.asyncio - async def test_concurrent_cleanup_no_duplicate_deletes(self, async_test_db, async_test_user): + async def test_concurrent_cleanup_no_duplicate_deletes( + self, async_test_db, async_test_user + ): """Test concurrent cleanups don't cause race conditions.""" - test_engine, AsyncTestingSessionLocal = async_test_db + _test_engine, AsyncTestingSessionLocal = async_test_db # Create 10 expired sessions async with AsyncTestingSessionLocal() as session: @@ -308,20 +356,24 @@ class TestConcurrentCleanup: ip_address="192.168.1.1", user_agent="Mozilla/5.0", is_active=False, - expires_at=datetime.now(timezone.utc) - timedelta(days=10), - created_at=datetime.now(timezone.utc) - timedelta(days=40), - last_used_at=datetime.now(timezone.utc) + expires_at=datetime.now(UTC) - timedelta(days=10), + created_at=datetime.now(UTC) - timedelta(days=40), + last_used_at=datetime.now(UTC), ) session.add(expired) await session.commit() # Run two cleanups concurrently # Use side_effect to return fresh session instances for each call - with patch('app.services.session_cleanup.SessionLocal', side_effect=lambda: AsyncTestingSessionLocal()): + with patch( + "app.services.session_cleanup.SessionLocal", + side_effect=lambda: AsyncTestingSessionLocal(), + ): from app.services.session_cleanup import cleanup_expired_sessions + results = await asyncio.gather( cleanup_expired_sessions(keep_days=30), - cleanup_expired_sessions(keep_days=30) + cleanup_expired_sessions(keep_days=30), ) # Both should report deleting sessions (may overlap due to transaction timing) diff --git a/backend/tests/test_init_db.py b/backend/tests/test_init_db.py index 713f313..1658760 100644 --- a/backend/tests/test_init_db.py +++ b/backend/tests/test_init_db.py @@ -2,12 +2,13 @@ """ Tests for database initialization script. """ -import pytest -import pytest_asyncio -from unittest.mock import AsyncMock, patch -from app.init_db import init_db +from unittest.mock import patch + +import pytest + from app.core.config import settings +from app.init_db import init_db class TestInitDb: @@ -16,69 +17,86 @@ class TestInitDb: @pytest.mark.asyncio async def test_init_db_creates_superuser_when_not_exists(self, async_test_db): """Test that init_db creates a superuser when one doesn't exist.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Mock the SessionLocal to use our test database - with patch('app.init_db.SessionLocal', SessionLocal): + with patch("app.init_db.SessionLocal", SessionLocal): # Mock settings to provide test credentials - with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test_admin@example.com'): - with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestAdmin123!'): + with patch.object( + settings, "FIRST_SUPERUSER_EMAIL", "test_admin@example.com" + ): + with patch.object( + settings, "FIRST_SUPERUSER_PASSWORD", "TestAdmin123!" + ): # Run init_db user = await init_db() # Verify superuser was created assert user is not None - assert user.email == 'test_admin@example.com' + assert user.email == "test_admin@example.com" assert user.is_superuser is True - assert user.first_name == 'Admin' - assert user.last_name == 'User' + assert user.first_name == "Admin" + assert user.last_name == "User" @pytest.mark.asyncio - async def test_init_db_returns_existing_superuser(self, async_test_db, async_test_user): + async def test_init_db_returns_existing_superuser( + self, async_test_db, async_test_user + ): """Test that init_db returns existing superuser instead of creating duplicate.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Mock the SessionLocal to use our test database - with patch('app.init_db.SessionLocal', SessionLocal): + with patch("app.init_db.SessionLocal", SessionLocal): # Mock settings to match async_test_user's email - with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'testuser@example.com'): - with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'): + with patch.object( + settings, "FIRST_SUPERUSER_EMAIL", "testuser@example.com" + ): + with patch.object( + settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!" + ): # Run init_db user = await init_db() # Verify it returns the existing user assert user is not None assert user.id == async_test_user.id - assert user.email == 'testuser@example.com' + assert user.email == "testuser@example.com" @pytest.mark.asyncio async def test_init_db_uses_default_credentials(self, async_test_db): """Test that init_db uses default credentials when env vars not set.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Mock the SessionLocal to use our test database - with patch('app.init_db.SessionLocal', SessionLocal): + with patch("app.init_db.SessionLocal", SessionLocal): # Mock settings to have None values (not configured) - with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', None): - with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', None): + with patch.object(settings, "FIRST_SUPERUSER_EMAIL", None): + with patch.object(settings, "FIRST_SUPERUSER_PASSWORD", None): # Run init_db user = await init_db() # Verify superuser was created with defaults assert user is not None - assert user.email == 'admin@example.com' + assert user.email == "admin@example.com" assert user.is_superuser is True @pytest.mark.asyncio async def test_init_db_handles_database_errors(self, async_test_db): """Test that init_db handles database errors gracefully.""" - test_engine, SessionLocal = async_test_db + _test_engine, SessionLocal = async_test_db # Mock user_crud.get_by_email to raise an exception - with patch('app.init_db.user_crud.get_by_email', side_effect=Exception("Database error")): - with patch('app.init_db.SessionLocal', SessionLocal): - with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test@example.com'): - with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'): + with patch( + "app.init_db.user_crud.get_by_email", + side_effect=Exception("Database error"), + ): + with patch("app.init_db.SessionLocal", SessionLocal): + with patch.object( + settings, "FIRST_SUPERUSER_EMAIL", "test@example.com" + ): + with patch.object( + settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!" + ): # Run init_db and expect it to raise with pytest.raises(Exception, match="Database error"): await init_db() diff --git a/backend/tests/utils/test_device.py b/backend/tests/utils/test_device.py index f122441..9816b4a 100644 --- a/backend/tests/utils/test_device.py +++ b/backend/tests/utils/test_device.py @@ -2,18 +2,18 @@ """ Comprehensive tests for device utility functions. """ -import pytest + from unittest.mock import Mock from fastapi import Request from app.utils.device import ( - extract_device_info, - parse_device_name, extract_browser, + extract_device_info, get_client_ip, + get_device_type, is_mobile_device, - get_device_type + parse_device_name, ) @@ -138,7 +138,9 @@ class TestExtractBrowser: def test_extract_browser_edge_legacy(self): """Test extracting legacy Edge browser.""" - ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582" + ua = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582" + ) result = extract_browser(ua) assert result == "Edge" @@ -249,7 +251,7 @@ class TestGetClientIp: request = Mock(spec=Request) request.headers = { "x-forwarded-for": "192.168.1.100", - "x-real-ip": "192.168.1.200" + "x-real-ip": "192.168.1.200", } request.client = Mock() request.client.host = "192.168.1.50" @@ -385,7 +387,7 @@ class TestExtractDeviceInfo: request.headers = { "user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)", "x-device-id": "device-123-456", - "x-forwarded-for": "192.168.1.100" + "x-forwarded-for": "192.168.1.100", } request.client = None diff --git a/backend/tests/utils/test_security.py b/backend/tests/utils/test_security.py index 52c4a9e..2b7d44b 100755 --- a/backend/tests/utils/test_security.py +++ b/backend/tests/utils/test_security.py @@ -2,19 +2,21 @@ """ Tests for security utility functions. """ -import time + import base64 import json +import time +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock from app.utils.security import ( - create_upload_token, - verify_upload_token, - create_password_reset_token, - verify_password_reset_token, create_email_verification_token, - verify_email_verification_token + create_password_reset_token, + create_upload_token, + verify_email_verification_token, + verify_password_reset_token, + verify_upload_token, ) @@ -31,7 +33,7 @@ class TestCreateUploadToken: # Token should be base64 encoded try: - decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + decoded = base64.urlsafe_b64decode(token.encode("utf-8")) token_data = json.loads(decoded) assert "payload" in token_data assert "signature" in token_data @@ -46,7 +48,7 @@ class TestCreateUploadToken: token = create_upload_token(file_path, content_type) # Decode and verify payload - decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + decoded = base64.urlsafe_b64decode(token.encode("utf-8")) token_data = json.loads(decoded) payload = token_data["payload"] @@ -62,7 +64,7 @@ class TestCreateUploadToken: after = int(time.time()) # Decode token - decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + decoded = base64.urlsafe_b64decode(token.encode("utf-8")) token_data = json.loads(decoded) payload = token_data["payload"] @@ -74,11 +76,13 @@ class TestCreateUploadToken: """Test token creation with custom expiration time.""" custom_exp = 600 # 10 minutes before = int(time.time()) - token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp) + token = create_upload_token( + "/uploads/test.jpg", "image/jpeg", expires_in=custom_exp + ) after = int(time.time()) # Decode token - decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + decoded = base64.urlsafe_b64decode(token.encode("utf-8")) token_data = json.loads(decoded) payload = token_data["payload"] @@ -92,11 +96,11 @@ class TestCreateUploadToken: token2 = create_upload_token("/uploads/test.jpg", "image/jpeg") # Decode both tokens - decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8')) + decoded1 = base64.urlsafe_b64decode(token1.encode("utf-8")) token_data1 = json.loads(decoded1) nonce1 = token_data1["payload"]["nonce"] - decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8')) + decoded2 = base64.urlsafe_b64decode(token2.encode("utf-8")) token_data2 = json.loads(decoded2) nonce2 = token_data2["payload"]["nonce"] @@ -133,7 +137,7 @@ class TestVerifyUploadToken: current_time = 1000000 mock_time.time = MagicMock(return_value=current_time) - with patch('app.utils.security.time', mock_time): + with patch("app.utils.security.time", mock_time): # Create token that "expires" at current_time + 1 token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1) @@ -149,13 +153,15 @@ class TestVerifyUploadToken: token = create_upload_token("/uploads/test.jpg", "image/jpeg") # Decode, modify, and re-encode - decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + decoded = base64.urlsafe_b64decode(token.encode("utf-8")) token_data = json.loads(decoded) token_data["signature"] = "invalid_signature" # Re-encode the tampered token tampered_json = json.dumps(token_data) - tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8') + tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode( + "utf-8" + ) payload = verify_upload_token(tampered_token) assert payload is None @@ -165,13 +171,15 @@ class TestVerifyUploadToken: token = create_upload_token("/uploads/test.jpg", "image/jpeg") # Decode, modify payload, and re-encode - decoded = base64.urlsafe_b64decode(token.encode('utf-8')) + decoded = base64.urlsafe_b64decode(token.encode("utf-8")) token_data = json.loads(decoded) token_data["payload"]["path"] = "/uploads/hacked.exe" # Re-encode the tampered token (signature won't match) tampered_json = json.dumps(token_data) - tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8') + tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode( + "utf-8" + ) payload = verify_upload_token(tampered_token) assert payload is None @@ -194,7 +202,9 @@ class TestVerifyUploadToken: """Test that tokens with invalid JSON are rejected.""" # Create a base64 string that decodes to invalid JSON invalid_json = "not valid json" - invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8') + invalid_token = base64.urlsafe_b64encode(invalid_json.encode("utf-8")).decode( + "utf-8" + ) payload = verify_upload_token(invalid_token) assert payload is None @@ -207,11 +217,13 @@ class TestVerifyUploadToken: "path": "/uploads/test.jpg" # Missing content_type, exp, nonce }, - "signature": "some_signature" + "signature": "some_signature", } incomplete_json = json.dumps(incomplete_data) - incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8') + incomplete_token = base64.urlsafe_b64encode( + incomplete_json.encode("utf-8") + ).decode("utf-8") payload = verify_upload_token(incomplete_token) assert payload is None @@ -266,7 +278,7 @@ class TestPasswordResetTokens: email = "user@example.com" # Create token that expires in 1 second - with patch('app.utils.security.time') as mock_time: + with patch("app.utils.security.time") as mock_time: mock_time.time = MagicMock(return_value=1000000) token = create_password_reset_token(email, expires_in=1) @@ -287,12 +299,14 @@ class TestPasswordResetTokens: token = create_password_reset_token(email) # Decode and tamper - decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8") token_data = json.loads(decoded) token_data["payload"]["email"] = "hacker@example.com" # Re-encode - tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8') + tampered = base64.urlsafe_b64encode( + json.dumps(token_data).encode("utf-8") + ).decode("utf-8") verified_email = verify_password_reset_token(tampered) assert verified_email is None @@ -312,14 +326,14 @@ class TestPasswordResetTokens: email = "user@example.com" custom_exp = 7200 # 2 hours - with patch('app.utils.security.time') as mock_time: + with patch("app.utils.security.time") as mock_time: current_time = 1000000 mock_time.time = MagicMock(return_value=current_time) token = create_password_reset_token(email, expires_in=custom_exp) # Decode to check expiration - decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8") token_data = json.loads(decoded) assert token_data["payload"]["exp"] == current_time + custom_exp @@ -350,7 +364,7 @@ class TestEmailVerificationTokens: """Test that expired verification tokens are rejected.""" email = "user@example.com" - with patch('app.utils.security.time') as mock_time: + with patch("app.utils.security.time") as mock_time: mock_time.time = MagicMock(return_value=1000000) token = create_email_verification_token(email, expires_in=1) @@ -371,12 +385,14 @@ class TestEmailVerificationTokens: token = create_email_verification_token(email) # Decode and tamper - decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8") token_data = json.loads(decoded) token_data["payload"]["email"] = "hacker@example.com" # Re-encode - tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8') + tampered = base64.urlsafe_b64encode( + json.dumps(token_data).encode("utf-8") + ).decode("utf-8") verified_email = verify_email_verification_token(tampered) assert verified_email is None @@ -395,14 +411,14 @@ class TestEmailVerificationTokens: """Test email verification token with default 24-hour expiration.""" email = "user@example.com" - with patch('app.utils.security.time') as mock_time: + with patch("app.utils.security.time") as mock_time: current_time = 1000000 mock_time.time = MagicMock(return_value=current_time) token = create_email_verification_token(email) # Decode to check expiration (should be 86400 seconds = 24 hours) - decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') + decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8") token_data = json.loads(decoded) assert token_data["payload"]["exp"] == current_time + 86400