Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff

- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest).
- Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting.
- Updated `requirements.txt` to include Ruff and remove replaced tools.
- Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
2025-11-10 11:55:15 +01:00
parent a5c671c133
commit c589b565f0
86 changed files with 4572 additions and 3956 deletions

View File

@@ -2,12 +2,11 @@ import sys
from logging.config import fileConfig from logging.config import fileConfig
from pathlib import Path from pathlib import Path
from sqlalchemy import engine_from_config, pool, text, create_engine from alembic import context
from sqlalchemy import create_engine, engine_from_config, pool, text
from sqlalchemy.engine.url import make_url from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from alembic import context
# Get the path to the app directory (parent of 'alembic') # Get the path to the app directory (parent of 'alembic')
app_dir = Path(__file__).resolve().parent.parent app_dir = Path(__file__).resolve().parent.parent
# Add the app directory to Python path # Add the app directory to Python path
@@ -66,7 +65,9 @@ def ensure_database_exists(db_url: str) -> None:
admin_url = url.set(database="postgres") admin_url = url.set(database="postgres")
# CREATE DATABASE cannot run inside a transaction # CREATE DATABASE cannot run inside a transaction
admin_engine = create_engine(str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool) admin_engine = create_engine(
str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool
)
try: try:
with admin_engine.connect() as conn: with admin_engine.connect() as conn:
exists = conn.execute( exists = conn.execute(
@@ -122,9 +123,7 @@ def run_migrations_online() -> None:
) )
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(connection=connection, target_metadata=target_metadata)
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

View File

@@ -5,17 +5,17 @@ Revises: fbf6318a8a36
Create Date: 2025-11-01 04:15:25.367010 Create Date: 2025-11-01 04:15:25.367010
""" """
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '1174fffbe3e4' revision: str = "1174fffbe3e4"
down_revision: Union[str, None] = 'fbf6318a8a36' down_revision: str | None = "fbf6318a8a36"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: str | Sequence[str] | None = None
depends_on: Union[str, Sequence[str], None] = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
@@ -24,46 +24,46 @@ def upgrade() -> None:
# Index for session cleanup queries # Index for session cleanup queries
# Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff # Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff
op.create_index( op.create_index(
'ix_user_sessions_cleanup', "ix_user_sessions_cleanup",
'user_sessions', "user_sessions",
['is_active', 'expires_at', 'created_at'], ["is_active", "expires_at", "created_at"],
unique=False, unique=False,
postgresql_where=sa.text('is_active = false') postgresql_where=sa.text("is_active = false"),
) )
# Index for user search queries (basic trigram support without pg_trgm extension) # Index for user search queries (basic trigram support without pg_trgm extension)
# Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%' # Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%'
# Note: For better performance, consider enabling pg_trgm extension # Note: For better performance, consider enabling pg_trgm extension
op.create_index( op.create_index(
'ix_users_email_lower', "ix_users_email_lower",
'users', "users",
[sa.text('LOWER(email)')], [sa.text("LOWER(email)")],
unique=False, unique=False,
postgresql_where=sa.text('deleted_at IS NULL') postgresql_where=sa.text("deleted_at IS NULL"),
) )
op.create_index( op.create_index(
'ix_users_first_name_lower', "ix_users_first_name_lower",
'users', "users",
[sa.text('LOWER(first_name)')], [sa.text("LOWER(first_name)")],
unique=False, unique=False,
postgresql_where=sa.text('deleted_at IS NULL') postgresql_where=sa.text("deleted_at IS NULL"),
) )
op.create_index( op.create_index(
'ix_users_last_name_lower', "ix_users_last_name_lower",
'users', "users",
[sa.text('LOWER(last_name)')], [sa.text("LOWER(last_name)")],
unique=False, unique=False,
postgresql_where=sa.text('deleted_at IS NULL') postgresql_where=sa.text("deleted_at IS NULL"),
) )
# Index for organization search # Index for organization search
op.create_index( op.create_index(
'ix_organizations_name_lower', "ix_organizations_name_lower",
'organizations', "organizations",
[sa.text('LOWER(name)')], [sa.text("LOWER(name)")],
unique=False unique=False,
) )
@@ -71,8 +71,8 @@ def downgrade() -> None:
"""Remove performance indexes.""" """Remove performance indexes."""
# Drop indexes in reverse order # Drop indexes in reverse order
op.drop_index('ix_organizations_name_lower', table_name='organizations') op.drop_index("ix_organizations_name_lower", table_name="organizations")
op.drop_index('ix_users_last_name_lower', table_name='users') op.drop_index("ix_users_last_name_lower", table_name="users")
op.drop_index('ix_users_first_name_lower', table_name='users') op.drop_index("ix_users_first_name_lower", table_name="users")
op.drop_index('ix_users_email_lower', table_name='users') op.drop_index("ix_users_email_lower", table_name="users")
op.drop_index('ix_user_sessions_cleanup', table_name='user_sessions') op.drop_index("ix_user_sessions_cleanup", table_name="user_sessions")

View File

@@ -5,30 +5,32 @@ Revises: 9e4f2a1b8c7d
Create Date: 2025-10-30 16:40:21.000021 Create Date: 2025-10-30 16:40:21.000021
""" """
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '2d0fcec3b06d' revision: str = "2d0fcec3b06d"
down_revision: Union[str, None] = '9e4f2a1b8c7d' down_revision: str | None = "9e4f2a1b8c7d"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: str | Sequence[str] | None = None
depends_on: Union[str, Sequence[str], None] = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
# Add deleted_at column for soft deletes # Add deleted_at column for soft deletes
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) op.add_column(
"users", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
)
# Add index on deleted_at for efficient queries # Add index on deleted_at for efficient queries
op.create_index('ix_users_deleted_at', 'users', ['deleted_at']) op.create_index("ix_users_deleted_at", "users", ["deleted_at"])
def downgrade() -> None: def downgrade() -> None:
# Remove index # Remove index
op.drop_index('ix_users_deleted_at', table_name='users') op.drop_index("ix_users_deleted_at", table_name="users")
# Remove column # Remove column
op.drop_column('users', 'deleted_at') op.drop_column("users", "deleted_at")

View File

@@ -5,42 +5,42 @@ Revises: 7396957cbe80
Create Date: 2025-02-28 09:19:33.212278 Create Date: 2025-02-28 09:19:33.212278
""" """
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '38bf9e7e74b3' revision: str = "38bf9e7e74b3"
down_revision: Union[str, None] = '7396957cbe80' down_revision: str | None = "7396957cbe80"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: str | Sequence[str] | None = None
depends_on: Union[str, Sequence[str], None] = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
op.create_table(
op.create_table('users', "users",
sa.Column('email', sa.String(), nullable=False), sa.Column("email", sa.String(), nullable=False),
sa.Column('password_hash', sa.String(), nullable=False), sa.Column("password_hash", sa.String(), nullable=False),
sa.Column('first_name', sa.String(), nullable=False), sa.Column("first_name", sa.String(), nullable=False),
sa.Column('last_name', sa.String(), nullable=True), sa.Column("last_name", sa.String(), nullable=True),
sa.Column('phone_number', sa.String(), nullable=True), sa.Column("phone_number", sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False), sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column('is_superuser', sa.Boolean(), nullable=False), sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column('preferences', sa.JSON(), nullable=True), sa.Column("preferences", sa.JSON(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False), sa.Column("id", sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_email'), table_name='users') op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_table('users') op.drop_table("users")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -5,98 +5,85 @@ Revises: b76c725fc3cf
Create Date: 2025-10-31 07:41:18.729544 Create Date: 2025-10-31 07:41:18.729544
""" """
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '549b50ea888d' revision: str = "549b50ea888d"
down_revision: Union[str, None] = 'b76c725fc3cf' down_revision: str | None = "b76c725fc3cf"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: str | Sequence[str] | None = None
depends_on: Union[str, Sequence[str], None] = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
# Create user_sessions table for per-device session management # Create user_sessions table for per-device session management
op.create_table( op.create_table(
'user_sessions', "user_sessions",
sa.Column('id', sa.UUID(), nullable=False), sa.Column("id", sa.UUID(), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=False), sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False), sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
sa.Column('device_name', sa.String(length=255), nullable=True), sa.Column("device_name", sa.String(length=255), nullable=True),
sa.Column('device_id', sa.String(length=255), nullable=True), sa.Column("device_id", sa.String(length=255), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True), sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.String(length=500), nullable=True), sa.Column("user_agent", sa.String(length=500), nullable=True),
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False), sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column('location_city', sa.String(length=100), nullable=True), sa.Column("location_city", sa.String(length=100), nullable=True),
sa.Column('location_country', sa.String(length=100), nullable=True), sa.Column("location_country", sa.String(length=100), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint("id"),
) )
# Create foreign key to users table # Create foreign key to users table
op.create_foreign_key( op.create_foreign_key(
'fk_user_sessions_user_id', "fk_user_sessions_user_id",
'user_sessions', "user_sessions",
'users', "users",
['user_id'], ["user_id"],
['id'], ["id"],
ondelete='CASCADE' ondelete="CASCADE",
) )
# Create indexes for performance # Create indexes for performance
# 1. Lookup session by refresh token JTI (most common query) # 1. Lookup session by refresh token JTI (most common query)
op.create_index( op.create_index(
'ix_user_sessions_jti', "ix_user_sessions_jti", "user_sessions", ["refresh_token_jti"], unique=True
'user_sessions',
['refresh_token_jti'],
unique=True
) )
# 2. Lookup sessions by user ID # 2. Lookup sessions by user ID
op.create_index( op.create_index("ix_user_sessions_user_id", "user_sessions", ["user_id"])
'ix_user_sessions_user_id',
'user_sessions',
['user_id']
)
# 3. Composite index for active sessions by user # 3. Composite index for active sessions by user
op.create_index( op.create_index(
'ix_user_sessions_user_active', "ix_user_sessions_user_active", "user_sessions", ["user_id", "is_active"]
'user_sessions',
['user_id', 'is_active']
) )
# 4. Index on expires_at for cleanup job # 4. Index on expires_at for cleanup job
op.create_index( op.create_index("ix_user_sessions_expires_at", "user_sessions", ["expires_at"])
'ix_user_sessions_expires_at',
'user_sessions',
['expires_at']
)
# 5. Composite index for active session lookup by JTI # 5. Composite index for active session lookup by JTI
op.create_index( op.create_index(
'ix_user_sessions_jti_active', "ix_user_sessions_jti_active",
'user_sessions', "user_sessions",
['refresh_token_jti', 'is_active'] ["refresh_token_jti", "is_active"],
) )
def downgrade() -> None: def downgrade() -> None:
# Drop indexes first # Drop indexes first
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions') op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
op.drop_index('ix_user_sessions_expires_at', table_name='user_sessions') op.drop_index("ix_user_sessions_expires_at", table_name="user_sessions")
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions') op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
op.drop_index('ix_user_sessions_user_id', table_name='user_sessions') op.drop_index("ix_user_sessions_user_id", table_name="user_sessions")
op.drop_index('ix_user_sessions_jti', table_name='user_sessions') op.drop_index("ix_user_sessions_jti", table_name="user_sessions")
# Drop foreign key # Drop foreign key
op.drop_constraint('fk_user_sessions_user_id', 'user_sessions', type_='foreignkey') op.drop_constraint("fk_user_sessions_user_id", "user_sessions", type_="foreignkey")
# Drop table # Drop table
op.drop_table('user_sessions') op.drop_table("user_sessions")

View File

@@ -5,15 +5,14 @@ Revises:
Create Date: 2025-02-27 12:47:46.445313 Create Date: 2025-02-27 12:47:46.445313
""" """
from typing import Sequence, Union
from alembic import op from collections.abc import Sequence
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '7396957cbe80' revision: str = "7396957cbe80"
down_revision: Union[str, None] = None down_revision: str | None = None
branch_labels: Union[str, Sequence[str], None] = None branch_labels: str | Sequence[str] | None = None
depends_on: Union[str, Sequence[str], None] = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:

View File

@@ -5,80 +5,112 @@ Revises: 38bf9e7e74b3
Create Date: 2025-10-30 10:00:00.000000 Create Date: 2025-10-30 10:00:00.000000
""" """
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '9e4f2a1b8c7d' revision: str = "9e4f2a1b8c7d"
down_revision: Union[str, None] = '38bf9e7e74b3' down_revision: str | None = "38bf9e7e74b3"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: str | Sequence[str] | None = None
depends_on: Union[str, Sequence[str], None] = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
# Add missing indexes for is_active and is_superuser # Add missing indexes for is_active and is_superuser
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False) op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False) op.create_index(
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
)
# Fix column types to match model definitions with explicit lengths # Fix column types to match model definitions with explicit lengths
op.alter_column('users', 'email', op.alter_column(
existing_type=sa.String(), "users",
type_=sa.String(length=255), "email",
nullable=False) existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False,
)
op.alter_column('users', 'password_hash', op.alter_column(
existing_type=sa.String(), "users",
type_=sa.String(length=255), "password_hash",
nullable=False) existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False,
)
op.alter_column('users', 'first_name', op.alter_column(
existing_type=sa.String(), "users",
type_=sa.String(length=100), "first_name",
nullable=False, existing_type=sa.String(),
server_default='user') # Add server default type_=sa.String(length=100),
nullable=False,
server_default="user",
) # Add server default
op.alter_column('users', 'last_name', op.alter_column(
existing_type=sa.String(), "users",
type_=sa.String(length=100), "last_name",
nullable=True) existing_type=sa.String(),
type_=sa.String(length=100),
nullable=True,
)
op.alter_column('users', 'phone_number', op.alter_column(
existing_type=sa.String(), "users",
type_=sa.String(length=20), "phone_number",
nullable=True) existing_type=sa.String(),
type_=sa.String(length=20),
nullable=True,
)
def downgrade() -> None: def downgrade() -> None:
# Revert column types # Revert column types
op.alter_column('users', 'phone_number', op.alter_column(
existing_type=sa.String(length=20), "users",
type_=sa.String(), "phone_number",
nullable=True) existing_type=sa.String(length=20),
type_=sa.String(),
nullable=True,
)
op.alter_column('users', 'last_name', op.alter_column(
existing_type=sa.String(length=100), "users",
type_=sa.String(), "last_name",
nullable=True) existing_type=sa.String(length=100),
type_=sa.String(),
nullable=True,
)
op.alter_column('users', 'first_name', op.alter_column(
existing_type=sa.String(length=100), "users",
type_=sa.String(), "first_name",
nullable=False, existing_type=sa.String(length=100),
server_default=None) # Remove server default type_=sa.String(),
nullable=False,
server_default=None,
) # Remove server default
op.alter_column('users', 'password_hash', op.alter_column(
existing_type=sa.String(length=255), "users",
type_=sa.String(), "password_hash",
nullable=False) existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False,
)
op.alter_column('users', 'email', op.alter_column(
existing_type=sa.String(length=255), "users",
type_=sa.String(), "email",
nullable=False) existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False,
)
# Drop indexes # Drop indexes
op.drop_index(op.f('ix_users_is_superuser'), table_name='users') op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
op.drop_index(op.f('ix_users_is_active'), table_name='users') op.drop_index(op.f("ix_users_is_active"), table_name="users")

View File

@@ -5,17 +5,17 @@ Revises: 2d0fcec3b06d
Create Date: 2025-10-30 16:41:33.273135 Create Date: 2025-10-30 16:41:33.273135
""" """
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'b76c725fc3cf' revision: str = "b76c725fc3cf"
down_revision: Union[str, None] = '2d0fcec3b06d' down_revision: str | None = "2d0fcec3b06d"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: str | Sequence[str] | None = None
depends_on: Union[str, Sequence[str], None] = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
@@ -23,30 +23,26 @@ def upgrade() -> None:
# Composite index for filtering active users by role # Composite index for filtering active users by role
op.create_index( op.create_index(
'ix_users_active_superuser', "ix_users_active_superuser",
'users', "users",
['is_active', 'is_superuser'], ["is_active", "is_superuser"],
postgresql_where=sa.text('deleted_at IS NULL') postgresql_where=sa.text("deleted_at IS NULL"),
) )
# Composite index for sorting active users by creation date # Composite index for sorting active users by creation date
op.create_index( op.create_index(
'ix_users_active_created', "ix_users_active_created",
'users', "users",
['is_active', 'created_at'], ["is_active", "created_at"],
postgresql_where=sa.text('deleted_at IS NULL') postgresql_where=sa.text("deleted_at IS NULL"),
) )
# Composite index for email lookup of non-deleted users # Composite index for email lookup of non-deleted users
op.create_index( op.create_index("ix_users_email_not_deleted", "users", ["email", "deleted_at"])
'ix_users_email_not_deleted',
'users',
['email', 'deleted_at']
)
def downgrade() -> None: def downgrade() -> None:
# Remove composite indexes # Remove composite indexes
op.drop_index('ix_users_email_not_deleted', table_name='users') op.drop_index("ix_users_email_not_deleted", table_name="users")
op.drop_index('ix_users_active_created', table_name='users') op.drop_index("ix_users_active_created", table_name="users")
op.drop_index('ix_users_active_superuser', table_name='users') op.drop_index("ix_users_active_superuser", table_name="users")

View File

@@ -5,102 +5,123 @@ Revises: 549b50ea888d
Create Date: 2025-10-31 12:08:05.141353 Create Date: 2025-10-31 12:08:05.141353
""" """
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'fbf6318a8a36' revision: str = "fbf6318a8a36"
down_revision: Union[str, None] = '549b50ea888d' down_revision: str | None = "549b50ea888d"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: str | Sequence[str] | None = None
depends_on: Union[str, Sequence[str], None] = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
# Create organizations table # Create organizations table
op.create_table( op.create_table(
'organizations', "organizations",
sa.Column('id', sa.UUID(), nullable=False), sa.Column("id", sa.UUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False), sa.Column("name", sa.String(length=255), nullable=False),
sa.Column('slug', sa.String(length=255), nullable=False), sa.Column("slug", sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True), sa.Column("description", sa.Text(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column('settings', sa.JSON(), nullable=True), sa.Column("settings", sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint("id"),
) )
# Create indexes for organizations # Create indexes for organizations
op.create_index('ix_organizations_name', 'organizations', ['name']) op.create_index("ix_organizations_name", "organizations", ["name"])
op.create_index('ix_organizations_slug', 'organizations', ['slug'], unique=True) op.create_index("ix_organizations_slug", "organizations", ["slug"], unique=True)
op.create_index('ix_organizations_is_active', 'organizations', ['is_active']) op.create_index("ix_organizations_is_active", "organizations", ["is_active"])
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active']) op.create_index(
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active']) "ix_organizations_name_active", "organizations", ["name", "is_active"]
)
op.create_index(
"ix_organizations_slug_active", "organizations", ["slug", "is_active"]
)
# Create user_organizations junction table # Create user_organizations junction table
op.create_table( op.create_table(
'user_organizations', "user_organizations",
sa.Column('user_id', sa.UUID(), nullable=False), sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column('organization_id', sa.UUID(), nullable=False), sa.Column("organization_id", sa.UUID(), nullable=False),
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False, server_default='MEMBER'), sa.Column(
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), "role",
sa.Column('custom_permissions', sa.String(length=500), nullable=True), sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), nullable=False,
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), server_default="MEMBER",
sa.PrimaryKeyConstraint('user_id', 'organization_id') ),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("user_id", "organization_id"),
) )
# Create foreign keys # Create foreign keys
op.create_foreign_key( op.create_foreign_key(
'fk_user_organizations_user_id', "fk_user_organizations_user_id",
'user_organizations', "user_organizations",
'users', "users",
['user_id'], ["user_id"],
['id'], ["id"],
ondelete='CASCADE' ondelete="CASCADE",
) )
op.create_foreign_key( op.create_foreign_key(
'fk_user_organizations_organization_id', "fk_user_organizations_organization_id",
'user_organizations', "user_organizations",
'organizations', "organizations",
['organization_id'], ["organization_id"],
['id'], ["id"],
ondelete='CASCADE' ondelete="CASCADE",
) )
# Create indexes for user_organizations # Create indexes for user_organizations
op.create_index('ix_user_organizations_role', 'user_organizations', ['role']) op.create_index("ix_user_organizations_role", "user_organizations", ["role"])
op.create_index('ix_user_organizations_is_active', 'user_organizations', ['is_active']) op.create_index(
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active']) "ix_user_organizations_is_active", "user_organizations", ["is_active"]
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active']) )
op.create_index(
"ix_user_org_user_active", "user_organizations", ["user_id", "is_active"]
)
op.create_index(
"ix_user_org_org_active", "user_organizations", ["organization_id", "is_active"]
)
def downgrade() -> None: def downgrade() -> None:
# Drop indexes for user_organizations # Drop indexes for user_organizations
op.drop_index('ix_user_org_org_active', table_name='user_organizations') op.drop_index("ix_user_org_org_active", table_name="user_organizations")
op.drop_index('ix_user_org_user_active', table_name='user_organizations') op.drop_index("ix_user_org_user_active", table_name="user_organizations")
op.drop_index('ix_user_organizations_is_active', table_name='user_organizations') op.drop_index("ix_user_organizations_is_active", table_name="user_organizations")
op.drop_index('ix_user_organizations_role', table_name='user_organizations') op.drop_index("ix_user_organizations_role", table_name="user_organizations")
# Drop foreign keys # Drop foreign keys
op.drop_constraint('fk_user_organizations_organization_id', 'user_organizations', type_='foreignkey') op.drop_constraint(
op.drop_constraint('fk_user_organizations_user_id', 'user_organizations', type_='foreignkey') "fk_user_organizations_organization_id",
"user_organizations",
type_="foreignkey",
)
op.drop_constraint(
"fk_user_organizations_user_id", "user_organizations", type_="foreignkey"
)
# Drop user_organizations table # Drop user_organizations table
op.drop_table('user_organizations') op.drop_table("user_organizations")
# Drop indexes for organizations # Drop indexes for organizations
op.drop_index('ix_organizations_slug_active', table_name='organizations') op.drop_index("ix_organizations_slug_active", table_name="organizations")
op.drop_index('ix_organizations_name_active', table_name='organizations') op.drop_index("ix_organizations_name_active", table_name="organizations")
op.drop_index('ix_organizations_is_active', table_name='organizations') op.drop_index("ix_organizations_is_active", table_name="organizations")
op.drop_index('ix_organizations_slug', table_name='organizations') op.drop_index("ix_organizations_slug", table_name="organizations")
op.drop_index('ix_organizations_name', table_name='organizations') op.drop_index("ix_organizations_name", table_name="organizations")
# Drop organizations table # Drop organizations table
op.drop_table('organizations') op.drop_table("organizations")
# Drop enum type # Drop enum type
op.execute('DROP TYPE IF EXISTS organizationrole') op.execute("DROP TYPE IF EXISTS organizationrole")

View File

@@ -1,12 +1,10 @@
from typing import Optional from fastapi import Depends, Header, HTTPException, status
from fastapi import Depends, HTTPException, status, Header
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
from app.core.database import get_db from app.core.database import get_db
from app.models.user import User from app.models.user import User
@@ -15,8 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user( async def get_current_user(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
token: str = Depends(oauth2_scheme)
) -> User: ) -> User:
""" """
Get the current authenticated user. Get the current authenticated user.
@@ -36,21 +33,17 @@ async def get_current_user(
token_data = get_token_data(token) token_data = get_token_data(token)
# Get user from database # Get user from database
result = await db.execute( result = await db.execute(select(User).where(User.id == token_data.user_id))
select(User).where(User.id == token_data.user_id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
detail="User not found"
) )
if not user.is_active: if not user.is_active:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
detail="Inactive user"
) )
return user return user
@@ -59,19 +52,17 @@ async def get_current_user(
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token expired", detail="Token expired",
headers={"WWW-Authenticate": "Bearer"} headers={"WWW-Authenticate": "Bearer"},
) )
except TokenInvalidError: except TokenInvalidError:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials", detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"} headers={"WWW-Authenticate": "Bearer"},
) )
def get_current_active_user( def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
current_user: User = Depends(get_current_user)
) -> User:
""" """
Check if the current user is active. Check if the current user is active.
@@ -86,15 +77,12 @@ def get_current_active_user(
""" """
if not current_user.is_active: if not current_user.is_active:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
detail="Inactive user"
) )
return current_user return current_user
def get_current_superuser( def get_current_superuser(current_user: User = Depends(get_current_user)) -> User:
current_user: User = Depends(get_current_user)
) -> User:
""" """
Check if the current user is a superuser. Check if the current user is a superuser.
@@ -109,13 +97,12 @@ def get_current_superuser(
""" """
if not current_user.is_superuser: if not current_user.is_superuser:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
detail="Not enough permissions"
) )
return current_user return current_user
async def get_optional_token(authorization: str = Header(None)) -> Optional[str]: async def get_optional_token(authorization: str = Header(None)) -> str | None:
""" """
Get the token from the Authorization header without requiring it. Get the token from the Authorization header without requiring it.
@@ -139,9 +126,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
async def get_optional_current_user( async def get_optional_current_user(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db), token: str | None = Depends(get_optional_token)
token: Optional[str] = Depends(get_optional_token) ) -> User | None:
) -> Optional[User]:
""" """
Get the current user if authenticated, otherwise return None. Get the current user if authenticated, otherwise return None.
Useful for endpoints that work with both authenticated and unauthenticated users. Useful for endpoints that work with both authenticated and unauthenticated users.
@@ -158,9 +144,7 @@ async def get_optional_current_user(
try: try:
token_data = get_token_data(token) token_data = get_token_data(token)
result = await db.execute( result = await db.execute(select(User).where(User.id == token_data.user_id))
select(User).where(User.id == token_data.user_id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
return None return None

View File

@@ -7,7 +7,7 @@ These dependencies are optional and flexible:
- Use require_org_role for organization-specific access control - Use require_org_role for organization-specific access control
- Projects can choose to use these or implement their own permission system - Projects can choose to use these or implement their own permission system
""" """
from typing import Optional
from uuid import UUID from uuid import UUID
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
@@ -20,9 +20,7 @@ from app.models.user import User
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
def require_superuser( def require_superuser(current_user: User = Depends(get_current_user)) -> User:
current_user: User = Depends(get_current_user)
) -> User:
""" """
Dependency to ensure the current user is a superuser. Dependency to ensure the current user is a superuser.
@@ -36,7 +34,7 @@ def require_superuser(
if not current_user.is_superuser: if not current_user.is_superuser:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Superuser privileges required" detail="Superuser privileges required",
) )
return current_user return current_user
@@ -62,7 +60,7 @@ class OrganizationPermission:
self, self,
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> User: ) -> User:
""" """
Check if user has required role in the organization. Check if user has required role in the organization.
@@ -84,21 +82,19 @@ class OrganizationPermission:
# Get user's role in organization # Get user's role in organization
user_role = await organization_crud.get_user_role_in_org( user_role = await organization_crud.get_user_role_in_org(
db, db, user_id=current_user.id, organization_id=organization_id
user_id=current_user.id,
organization_id=organization_id
) )
if not user_role: if not user_role:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Not a member of this organization" detail="Not a member of this organization",
) )
if user_role not in self.allowed_roles: if user_role not in self.allowed_roles:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}" detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}",
) )
return current_user return current_user
@@ -106,18 +102,18 @@ class OrganizationPermission:
# Common permission presets for convenience # Common permission presets for convenience
require_org_owner = OrganizationPermission([OrganizationRole.OWNER]) require_org_owner = OrganizationPermission([OrganizationRole.OWNER])
require_org_admin = OrganizationPermission([OrganizationRole.OWNER, OrganizationRole.ADMIN]) require_org_admin = OrganizationPermission(
require_org_member = OrganizationPermission([ [OrganizationRole.OWNER, OrganizationRole.ADMIN]
OrganizationRole.OWNER, )
OrganizationRole.ADMIN, require_org_member = OrganizationPermission(
OrganizationRole.MEMBER [OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MEMBER]
]) )
async def require_org_membership( async def require_org_membership(
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> User: ) -> User:
""" """
Ensure user is a member of the organization (any role). Ensure user is a member of the organization (any role).
@@ -128,15 +124,13 @@ async def require_org_membership(
return current_user return current_user
user_role = await organization_crud.get_user_role_in_org( user_role = await organization_crud.get_user_role_in_org(
db, db, user_id=current_user.id, organization_id=organization_id
user_id=current_user.id,
organization_id=organization_id
) )
if not user_role: if not user_role:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Not a member of this organization" detail="Not a member of this organization",
) )
return current_user return current_user

View File

@@ -1,10 +1,12 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api.routes import auth, users, sessions, admin, organizations from app.api.routes import admin, auth, organizations, sessions, users
api_router = APIRouter() api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"]) api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
api_router.include_router(users.router, prefix="/users", tags=["Users"]) api_router.include_router(users.router, prefix="/users", tags=["Users"])
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"]) api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"]) api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
api_router.include_router(organizations.router, prefix="/organizations", tags=["Organizations"]) api_router.include_router(
organizations.router, prefix="/organizations", tags=["Organizations"]
)

View File

@@ -5,9 +5,10 @@ Admin-specific endpoints for managing users and organizations.
These endpoints require superuser privileges and provide CMS-like functionality These endpoints require superuser privileges and provide CMS-like functionality
for managing the application. for managing the application.
""" """
import logging import logging
from enum import Enum from enum import Enum
from typing import Any, List, Optional from typing import Any
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query, status from fastapi import APIRouter, Depends, Query, status
@@ -16,27 +17,32 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.permissions import require_superuser from app.api.dependencies.permissions import require_superuser
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode from app.core.exceptions import (
AuthorizationError,
DuplicateError,
ErrorCode,
NotFoundError,
)
from app.crud.organization import organization as organization_crud from app.crud.organization import organization as organization_crud
from app.crud.user import user as user_crud
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
from app.schemas.common import ( from app.schemas.common import (
PaginationParams,
PaginatedResponse,
MessageResponse, MessageResponse,
PaginatedResponse,
PaginationParams,
SortParams, SortParams,
create_pagination_meta create_pagination_meta,
) )
from app.schemas.organizations import ( from app.schemas.organizations import (
OrganizationResponse,
OrganizationCreate, OrganizationCreate,
OrganizationMemberResponse,
OrganizationResponse,
OrganizationUpdate, OrganizationUpdate,
OrganizationMemberResponse
) )
from app.schemas.users import UserResponse, UserCreate, UserUpdate
from app.schemas.sessions import AdminSessionResponse from app.schemas.sessions import AdminSessionResponse
from app.schemas.users import UserCreate, UserResponse, UserUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,6 +52,7 @@ router = APIRouter()
# Schemas for bulk operations # Schemas for bulk operations
class BulkAction(str, Enum): class BulkAction(str, Enum):
"""Supported bulk actions.""" """Supported bulk actions."""
ACTIVATE = "activate" ACTIVATE = "activate"
DEACTIVATE = "deactivate" DEACTIVATE = "deactivate"
DELETE = "delete" DELETE = "delete"
@@ -53,36 +60,41 @@ class BulkAction(str, Enum):
class BulkUserAction(BaseModel): class BulkUserAction(BaseModel):
"""Schema for bulk user actions.""" """Schema for bulk user actions."""
action: BulkAction = Field(..., description="Action to perform on selected users") action: BulkAction = Field(..., description="Action to perform on selected users")
user_ids: List[UUID] = Field(..., min_items=1, max_items=100, description="List of user IDs (max 100)") user_ids: list[UUID] = Field(
..., min_items=1, max_items=100, description="List of user IDs (max 100)"
)
class BulkActionResult(BaseModel): class BulkActionResult(BaseModel):
"""Result of a bulk action.""" """Result of a bulk action."""
success: bool success: bool
affected_count: int affected_count: int
failed_count: int failed_count: int
message: str message: str
failed_ids: Optional[List[UUID]] = [] failed_ids: list[UUID] | None = []
# ===== User Management Endpoints ===== # ===== User Management Endpoints =====
@router.get( @router.get(
"/users", "/users",
response_model=PaginatedResponse[UserResponse], response_model=PaginatedResponse[UserResponse],
summary="Admin: List All Users", summary="Admin: List All Users",
description="Get paginated list of all users with filtering and search (admin only)", description="Get paginated list of all users with filtering and search (admin only)",
operation_id="admin_list_users" operation_id="admin_list_users",
) )
async def admin_list_users( async def admin_list_users(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
sort: SortParams = Depends(), sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"), is_active: bool | None = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), is_superuser: bool | None = Query(None, description="Filter by superuser status"),
search: Optional[str] = Query(None, description="Search by email, name"), search: str | None = Query(None, description="Search by email, name"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
List all users with comprehensive filtering and search. List all users with comprehensive filtering and search.
@@ -105,20 +117,20 @@ async def admin_list_users(
sort_by=sort.sort_by or "created_at", sort_by=sort.sort_by or "created_at",
sort_order=sort.sort_order.value if sort.sort_order else "desc", sort_order=sort.sort_order.value if sort.sort_order else "desc",
filters=filters if filters else None, filters=filters if filters else None,
search=search search=search,
) )
pagination_meta = create_pagination_meta( pagination_meta = create_pagination_meta(
total=total, total=total,
page=pagination.page, page=pagination.page,
limit=pagination.limit, limit=pagination.limit,
items_count=len(users) items_count=len(users),
) )
return PaginatedResponse(data=users, pagination=pagination_meta) return PaginatedResponse(data=users, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error listing users (admin): {str(e)}", exc_info=True) logger.error(f"Error listing users (admin): {e!s}", exc_info=True)
raise raise
@@ -128,12 +140,12 @@ async def admin_list_users(
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
summary="Admin: Create User", summary="Admin: Create User",
description="Create a new user (admin only)", description="Create a new user (admin only)",
operation_id="admin_create_user" operation_id="admin_create_user",
) )
async def admin_create_user( async def admin_create_user(
user_in: UserCreate, user_in: UserCreate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Create a new user with admin privileges. Create a new user with admin privileges.
@@ -145,13 +157,10 @@ async def admin_create_user(
logger.info(f"Admin {admin.email} created user {user.email}") logger.info(f"Admin {admin.email} created user {user.email}")
return user return user
except ValueError as e: except ValueError as e:
logger.warning(f"Failed to create user: {str(e)}") logger.warning(f"Failed to create user: {e!s}")
raise NotFoundError( raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
message=str(e),
error_code=ErrorCode.USER_ALREADY_EXISTS
)
except Exception as e: except Exception as e:
logger.error(f"Error creating user (admin): {str(e)}", exc_info=True) logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
raise raise
@@ -160,19 +169,18 @@ async def admin_create_user(
response_model=UserResponse, response_model=UserResponse,
summary="Admin: Get User Details", summary="Admin: Get User Details",
description="Get detailed user information (admin only)", description="Get detailed user information (admin only)",
operation_id="admin_get_user" operation_id="admin_get_user",
) )
async def admin_get_user( async def admin_get_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Get detailed information about a specific user.""" """Get detailed information about a specific user."""
user = await user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User {user_id} not found", message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND
) )
return user return user
@@ -182,21 +190,20 @@ async def admin_get_user(
response_model=UserResponse, response_model=UserResponse,
summary="Admin: Update User", summary="Admin: Update User",
description="Update user information (admin only)", description="Update user information (admin only)",
operation_id="admin_update_user" operation_id="admin_update_user",
) )
async def admin_update_user( async def admin_update_user(
user_id: UUID, user_id: UUID,
user_in: UserUpdate, user_in: UserUpdate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Update user information with admin privileges.""" """Update user information with admin privileges."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User {user_id} not found", message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND
) )
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in) updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
@@ -206,7 +213,7 @@ async def admin_update_user(
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error updating user (admin): {str(e)}", exc_info=True) logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
raise raise
@@ -215,20 +222,19 @@ async def admin_update_user(
response_model=MessageResponse, response_model=MessageResponse,
summary="Admin: Delete User", summary="Admin: Delete User",
description="Soft delete a user (admin only)", description="Soft delete a user (admin only)",
operation_id="admin_delete_user" operation_id="admin_delete_user",
) )
async def admin_delete_user( async def admin_delete_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Soft delete a user (sets deleted_at timestamp).""" """Soft delete a user (sets deleted_at timestamp)."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User {user_id} not found", message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND
) )
# Prevent deleting yourself # Prevent deleting yourself
@@ -236,21 +242,20 @@ async def admin_delete_user(
# Use AuthorizationError for permission/operation restrictions # Use AuthorizationError for permission/operation restrictions
raise AuthorizationError( raise AuthorizationError(
message="Cannot delete your own account", message="Cannot delete your own account",
error_code=ErrorCode.OPERATION_FORBIDDEN error_code=ErrorCode.OPERATION_FORBIDDEN,
) )
await user_crud.soft_delete(db, id=user_id) await user_crud.soft_delete(db, id=user_id)
logger.info(f"Admin {admin.email} deleted user {user.email}") logger.info(f"Admin {admin.email} deleted user {user.email}")
return MessageResponse( return MessageResponse(
success=True, success=True, message=f"User {user.email} has been deleted"
message=f"User {user.email} has been deleted"
) )
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error deleting user (admin): {str(e)}", exc_info=True) logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
raise raise
@@ -259,34 +264,32 @@ async def admin_delete_user(
response_model=MessageResponse, response_model=MessageResponse,
summary="Admin: Activate User", summary="Admin: Activate User",
description="Activate a user account (admin only)", description="Activate a user account (admin only)",
operation_id="admin_activate_user" operation_id="admin_activate_user",
) )
async def admin_activate_user( async def admin_activate_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Activate a user account.""" """Activate a user account."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User {user_id} not found", message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND
) )
await user_crud.update(db, db_obj=user, obj_in={"is_active": True}) await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
logger.info(f"Admin {admin.email} activated user {user.email}") logger.info(f"Admin {admin.email} activated user {user.email}")
return MessageResponse( return MessageResponse(
success=True, success=True, message=f"User {user.email} has been activated"
message=f"User {user.email} has been activated"
) )
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error activating user (admin): {str(e)}", exc_info=True) logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
raise raise
@@ -295,20 +298,19 @@ async def admin_activate_user(
response_model=MessageResponse, response_model=MessageResponse,
summary="Admin: Deactivate User", summary="Admin: Deactivate User",
description="Deactivate a user account (admin only)", description="Deactivate a user account (admin only)",
operation_id="admin_deactivate_user" operation_id="admin_deactivate_user",
) )
async def admin_deactivate_user( async def admin_deactivate_user(
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Deactivate a user account.""" """Deactivate a user account."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User {user_id} not found", message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND
) )
# Prevent deactivating yourself # Prevent deactivating yourself
@@ -316,21 +318,20 @@ async def admin_deactivate_user(
# Use AuthorizationError for permission/operation restrictions # Use AuthorizationError for permission/operation restrictions
raise AuthorizationError( raise AuthorizationError(
message="Cannot deactivate your own account", message="Cannot deactivate your own account",
error_code=ErrorCode.OPERATION_FORBIDDEN error_code=ErrorCode.OPERATION_FORBIDDEN,
) )
await user_crud.update(db, db_obj=user, obj_in={"is_active": False}) await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
logger.info(f"Admin {admin.email} deactivated user {user.email}") logger.info(f"Admin {admin.email} deactivated user {user.email}")
return MessageResponse( return MessageResponse(
success=True, success=True, message=f"User {user.email} has been deactivated"
message=f"User {user.email} has been deactivated"
) )
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error deactivating user (admin): {str(e)}", exc_info=True) logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
raise raise
@@ -339,12 +340,12 @@ async def admin_deactivate_user(
response_model=BulkActionResult, response_model=BulkActionResult,
summary="Admin: Bulk User Action", summary="Admin: Bulk User Action",
description="Perform bulk actions on multiple users (admin only)", description="Perform bulk actions on multiple users (admin only)",
operation_id="admin_bulk_user_action" operation_id="admin_bulk_user_action",
) )
async def admin_bulk_user_action( async def admin_bulk_user_action(
bulk_action: BulkUserAction, bulk_action: BulkUserAction,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Perform bulk actions on multiple users using optimized bulk operations. Perform bulk actions on multiple users using optimized bulk operations.
@@ -356,22 +357,16 @@ async def admin_bulk_user_action(
# Use efficient bulk operations instead of loop # Use efficient bulk operations instead of loop
if bulk_action.action == BulkAction.ACTIVATE: if bulk_action.action == BulkAction.ACTIVATE:
affected_count = await user_crud.bulk_update_status( affected_count = await user_crud.bulk_update_status(
db, db, user_ids=bulk_action.user_ids, is_active=True
user_ids=bulk_action.user_ids,
is_active=True
) )
elif bulk_action.action == BulkAction.DEACTIVATE: elif bulk_action.action == BulkAction.DEACTIVATE:
affected_count = await user_crud.bulk_update_status( affected_count = await user_crud.bulk_update_status(
db, db, user_ids=bulk_action.user_ids, is_active=False
user_ids=bulk_action.user_ids,
is_active=False
) )
elif bulk_action.action == BulkAction.DELETE: elif bulk_action.action == BulkAction.DELETE:
# bulk_soft_delete automatically excludes the admin user # bulk_soft_delete automatically excludes the admin user
affected_count = await user_crud.bulk_soft_delete( affected_count = await user_crud.bulk_soft_delete(
db, db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
user_ids=bulk_action.user_ids,
exclude_user_id=admin.id
) )
else: else:
raise ValueError(f"Unsupported bulk action: {bulk_action.action}") raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
@@ -390,29 +385,30 @@ async def admin_bulk_user_action(
affected_count=affected_count, affected_count=affected_count,
failed_count=failed_count, failed_count=failed_count,
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped", message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped",
failed_ids=None # Bulk operations don't track individual failures failed_ids=None, # Bulk operations don't track individual failures
) )
except Exception as e: except Exception as e:
logger.error(f"Error in bulk user action: {str(e)}", exc_info=True) logger.error(f"Error in bulk user action: {e!s}", exc_info=True)
raise raise
# ===== Organization Management Endpoints ===== # ===== Organization Management Endpoints =====
@router.get( @router.get(
"/organizations", "/organizations",
response_model=PaginatedResponse[OrganizationResponse], response_model=PaginatedResponse[OrganizationResponse],
summary="Admin: List Organizations", summary="Admin: List Organizations",
description="Get paginated list of all organizations (admin only)", description="Get paginated list of all organizations (admin only)",
operation_id="admin_list_organizations" operation_id="admin_list_organizations",
) )
async def admin_list_organizations( async def admin_list_organizations(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"), is_active: bool | None = Query(None, description="Filter by active status"),
search: Optional[str] = Query(None, description="Search by name, slug, description"), search: str | None = Query(None, description="Search by name, slug, description"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""List all organizations with filtering and search.""" """List all organizations with filtering and search."""
try: try:
@@ -422,14 +418,14 @@ async def admin_list_organizations(
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
is_active=is_active, is_active=is_active,
search=search search=search,
) )
# Build response objects from optimized query results # Build response objects from optimized query results
orgs_with_count = [] orgs_with_count = []
for item in orgs_with_data: for item in orgs_with_data:
org = item['organization'] org = item["organization"]
member_count = item['member_count'] member_count = item["member_count"]
org_dict = { org_dict = {
"id": org.id, "id": org.id,
@@ -440,7 +436,7 @@ async def admin_list_organizations(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": member_count "member_count": member_count,
} }
orgs_with_count.append(OrganizationResponse(**org_dict)) orgs_with_count.append(OrganizationResponse(**org_dict))
@@ -448,13 +444,13 @@ async def admin_list_organizations(
total=total, total=total,
page=pagination.page, page=pagination.page,
limit=pagination.limit, limit=pagination.limit,
items_count=len(orgs_with_count) items_count=len(orgs_with_count),
) )
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta) return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error listing organizations (admin): {str(e)}", exc_info=True) logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True)
raise raise
@@ -464,12 +460,12 @@ async def admin_list_organizations(
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
summary="Admin: Create Organization", summary="Admin: Create Organization",
description="Create a new organization (admin only)", description="Create a new organization (admin only)",
operation_id="admin_create_organization" operation_id="admin_create_organization",
) )
async def admin_create_organization( async def admin_create_organization(
org_in: OrganizationCreate, org_in: OrganizationCreate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Create a new organization.""" """Create a new organization."""
try: try:
@@ -486,18 +482,15 @@ async def admin_create_organization(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": 0 "member_count": 0,
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except ValueError as e: except ValueError as e:
logger.warning(f"Failed to create organization: {str(e)}") logger.warning(f"Failed to create organization: {e!s}")
raise NotFoundError( raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
message=str(e),
error_code=ErrorCode.ALREADY_EXISTS
)
except Exception as e: except Exception as e:
logger.error(f"Error creating organization (admin): {str(e)}", exc_info=True) logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
raise raise
@@ -506,19 +499,18 @@ async def admin_create_organization(
response_model=OrganizationResponse, response_model=OrganizationResponse,
summary="Admin: Get Organization Details", summary="Admin: Get Organization Details",
description="Get detailed organization information (admin only)", description="Get detailed organization information (admin only)",
operation_id="admin_get_organization" operation_id="admin_get_organization",
) )
async def admin_get_organization( async def admin_get_organization(
org_id: UUID, org_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Get detailed information about a specific organization.""" """Get detailed information about a specific organization."""
org = await organization_crud.get(db, id=org_id) org = await organization_crud.get(db, id=org_id)
if not org: if not org:
raise NotFoundError( raise NotFoundError(
message=f"Organization {org_id} not found", message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND
) )
org_dict = { org_dict = {
@@ -530,7 +522,9 @@ async def admin_get_organization(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=org.id) "member_count": await organization_crud.get_member_count(
db, organization_id=org.id
),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
@@ -540,13 +534,13 @@ async def admin_get_organization(
response_model=OrganizationResponse, response_model=OrganizationResponse,
summary="Admin: Update Organization", summary="Admin: Update Organization",
description="Update organization information (admin only)", description="Update organization information (admin only)",
operation_id="admin_update_organization" operation_id="admin_update_organization",
) )
async def admin_update_organization( async def admin_update_organization(
org_id: UUID, org_id: UUID,
org_in: OrganizationUpdate, org_in: OrganizationUpdate,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Update organization information.""" """Update organization information."""
try: try:
@@ -554,7 +548,7 @@ async def admin_update_organization(
if not org: if not org:
raise NotFoundError( raise NotFoundError(
message=f"Organization {org_id} not found", message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND,
) )
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in) updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
@@ -569,14 +563,16 @@ async def admin_update_organization(
"settings": updated_org.settings, "settings": updated_org.settings,
"created_at": updated_org.created_at, "created_at": updated_org.created_at,
"updated_at": updated_org.updated_at, "updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id) "member_count": await organization_crud.get_member_count(
db, organization_id=updated_org.id
),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error updating organization (admin): {str(e)}", exc_info=True) logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
raise raise
@@ -585,12 +581,12 @@ async def admin_update_organization(
response_model=MessageResponse, response_model=MessageResponse,
summary="Admin: Delete Organization", summary="Admin: Delete Organization",
description="Delete an organization (admin only)", description="Delete an organization (admin only)",
operation_id="admin_delete_organization" operation_id="admin_delete_organization",
) )
async def admin_delete_organization( async def admin_delete_organization(
org_id: UUID, org_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Delete an organization and all its relationships.""" """Delete an organization and all its relationships."""
try: try:
@@ -598,21 +594,20 @@ async def admin_delete_organization(
if not org: if not org:
raise NotFoundError( raise NotFoundError(
message=f"Organization {org_id} not found", message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND,
) )
await organization_crud.remove(db, id=org_id) await organization_crud.remove(db, id=org_id)
logger.info(f"Admin {admin.email} deleted organization {org.name}") logger.info(f"Admin {admin.email} deleted organization {org.name}")
return MessageResponse( return MessageResponse(
success=True, success=True, message=f"Organization {org.name} has been deleted"
message=f"Organization {org.name} has been deleted"
) )
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error deleting organization (admin): {str(e)}", exc_info=True) logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
raise raise
@@ -621,14 +616,14 @@ async def admin_delete_organization(
response_model=PaginatedResponse[OrganizationMemberResponse], response_model=PaginatedResponse[OrganizationMemberResponse],
summary="Admin: List Organization Members", summary="Admin: List Organization Members",
description="Get all members of an organization (admin only)", description="Get all members of an organization (admin only)",
operation_id="admin_list_organization_members" operation_id="admin_list_organization_members",
) )
async def admin_list_organization_members( async def admin_list_organization_members(
org_id: UUID, org_id: UUID,
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(True, description="Filter by active status"), is_active: bool | None = Query(True, description="Filter by active status"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""List all members of an organization.""" """List all members of an organization."""
try: try:
@@ -636,7 +631,7 @@ async def admin_list_organization_members(
if not org: if not org:
raise NotFoundError( raise NotFoundError(
message=f"Organization {org_id} not found", message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND,
) )
members, total = await organization_crud.get_organization_members( members, total = await organization_crud.get_organization_members(
@@ -644,7 +639,7 @@ async def admin_list_organization_members(
organization_id=org_id, organization_id=org_id,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
is_active=is_active is_active=is_active,
) )
# Convert to response models # Convert to response models
@@ -654,7 +649,7 @@ async def admin_list_organization_members(
total=total, total=total,
page=pagination.page, page=pagination.page,
limit=pagination.limit, limit=pagination.limit,
items_count=len(member_responses) items_count=len(member_responses),
) )
return PaginatedResponse(data=member_responses, pagination=pagination_meta) return PaginatedResponse(data=member_responses, pagination=pagination_meta)
@@ -662,14 +657,19 @@ async def admin_list_organization_members(
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error listing organization members (admin): {str(e)}", exc_info=True) logger.error(
f"Error listing organization members (admin): {e!s}", exc_info=True
)
raise raise
class AddMemberRequest(BaseModel): class AddMemberRequest(BaseModel):
"""Request to add a member to an organization.""" """Request to add a member to an organization."""
user_id: UUID = Field(..., description="User ID to add") user_id: UUID = Field(..., description="User ID to add")
role: OrganizationRole = Field(OrganizationRole.MEMBER, description="Role in organization") role: OrganizationRole = Field(
OrganizationRole.MEMBER, description="Role in organization"
)
@router.post( @router.post(
@@ -677,13 +677,13 @@ class AddMemberRequest(BaseModel):
response_model=MessageResponse, response_model=MessageResponse,
summary="Admin: Add Member to Organization", summary="Admin: Add Member to Organization",
description="Add a user to an organization (admin only)", description="Add a user to an organization (admin only)",
operation_id="admin_add_organization_member" operation_id="admin_add_organization_member",
) )
async def admin_add_organization_member( async def admin_add_organization_member(
org_id: UUID, org_id: UUID,
request: AddMemberRequest, request: AddMemberRequest,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Add a user to an organization.""" """Add a user to an organization."""
try: try:
@@ -691,21 +691,18 @@ async def admin_add_organization_member(
if not org: if not org:
raise NotFoundError( raise NotFoundError(
message=f"Organization {org_id} not found", message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND,
) )
user = await user_crud.get(db, id=request.user_id) user = await user_crud.get(db, id=request.user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User {request.user_id} not found", message=f"User {request.user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND error_code=ErrorCode.USER_NOT_FOUND,
) )
await organization_crud.add_user( await organization_crud.add_user(
db, db, organization_id=org_id, user_id=request.user_id, role=request.role
organization_id=org_id,
user_id=request.user_id,
role=request.role
) )
logger.info( logger.info(
@@ -714,22 +711,21 @@ async def admin_add_organization_member(
) )
return MessageResponse( return MessageResponse(
success=True, success=True, message=f"User {user.email} added to organization {org.name}"
message=f"User {user.email} added to organization {org.name}"
) )
except ValueError as e: except ValueError as e:
logger.warning(f"Failed to add user to organization: {str(e)}") logger.warning(f"Failed to add user to organization: {e!s}")
# Use DuplicateError for "already exists" scenarios # Use DuplicateError for "already exists" scenarios
raise DuplicateError( raise DuplicateError(
message=str(e), message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
error_code=ErrorCode.USER_ALREADY_EXISTS,
field="user_id"
) )
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error adding member to organization (admin): {str(e)}", exc_info=True) logger.error(
f"Error adding member to organization (admin): {e!s}", exc_info=True
)
raise raise
@@ -738,13 +734,13 @@ async def admin_add_organization_member(
response_model=MessageResponse, response_model=MessageResponse,
summary="Admin: Remove Member from Organization", summary="Admin: Remove Member from Organization",
description="Remove a user from an organization (admin only)", description="Remove a user from an organization (admin only)",
operation_id="admin_remove_organization_member" operation_id="admin_remove_organization_member",
) )
async def admin_remove_organization_member( async def admin_remove_organization_member(
org_id: UUID, org_id: UUID,
user_id: UUID, user_id: UUID,
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Remove a user from an organization.""" """Remove a user from an organization."""
try: try:
@@ -752,39 +748,40 @@ async def admin_remove_organization_member(
if not org: if not org:
raise NotFoundError( raise NotFoundError(
message=f"Organization {org_id} not found", message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND,
) )
user = await user_crud.get(db, id=user_id) user = await user_crud.get(db, id=user_id)
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User {user_id} not found", message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND
) )
success = await organization_crud.remove_user( success = await organization_crud.remove_user(
db, db, organization_id=org_id, user_id=user_id
organization_id=org_id,
user_id=user_id
) )
if not success: if not success:
raise NotFoundError( raise NotFoundError(
message="User is not a member of this organization", message="User is not a member of this organization",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND,
) )
logger.info(f"Admin {admin.email} removed user {user.email} from organization {org.name}") logger.info(
f"Admin {admin.email} removed user {user.email} from organization {org.name}"
)
return MessageResponse( return MessageResponse(
success=True, success=True,
message=f"User {user.email} removed from organization {org.name}" message=f"User {user.email} removed from organization {org.name}",
) )
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error removing member from organization (admin): {str(e)}", exc_info=True) logger.error(
f"Error removing member from organization (admin): {e!s}", exc_info=True
)
raise raise
@@ -792,6 +789,7 @@ async def admin_remove_organization_member(
# Session Management Endpoints # Session Management Endpoints
# ============================================================================ # ============================================================================
@router.get( @router.get(
"/sessions", "/sessions",
response_model=PaginatedResponse[AdminSessionResponse], response_model=PaginatedResponse[AdminSessionResponse],
@@ -802,13 +800,13 @@ async def admin_remove_organization_member(
Returns paginated list of sessions with user information. Returns paginated list of sessions with user information.
Useful for admin dashboard statistics and session monitoring. Useful for admin dashboard statistics and session monitoring.
""", """,
operation_id="admin_list_sessions" operation_id="admin_list_sessions",
) )
async def admin_list_sessions( async def admin_list_sessions(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"), is_active: bool | None = Query(None, description="Filter by active status"),
admin: User = Depends(require_superuser), admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""List all sessions across all users with filtering and pagination.""" """List all sessions across all users with filtering and pagination."""
try: try:
@@ -818,7 +816,7 @@ async def admin_list_sessions(
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
active_only=is_active if is_active is not None else True, active_only=is_active if is_active is not None else True,
with_user=True with_user=True,
) )
# Build response objects with user information # Build response objects with user information
@@ -847,21 +845,23 @@ async def admin_list_sessions(
last_used_at=session.last_used_at, last_used_at=session.last_used_at,
created_at=session.created_at, created_at=session.created_at,
expires_at=session.expires_at, expires_at=session.expires_at,
is_active=session.is_active is_active=session.is_active,
) )
session_responses.append(session_response) session_responses.append(session_response)
logger.info(f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})") logger.info(
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})"
)
pagination_meta = create_pagination_meta( pagination_meta = create_pagination_meta(
total=total, total=total,
page=pagination.page, page=pagination.page,
limit=pagination.limit, limit=pagination.limit,
items_count=len(session_responses) items_count=len(session_responses),
) )
return PaginatedResponse(data=session_responses, pagination=pagination_meta) return PaginatedResponse(data=session_responses, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error listing sessions (admin): {str(e)}", exc_info=True) logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True)
raise raise

View File

@@ -1,39 +1,43 @@
# app/api/routes/auth.py # app/api/routes/auth.py
import logging import logging
import os import os
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter from slowapi import Limiter
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token from app.core.auth import (
from app.core.auth import get_password_hash TokenExpiredError,
TokenInvalidError,
decode_token,
get_password_hash,
)
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import ( from app.core.exceptions import (
AuthenticationError as AuthError, AuthenticationError as AuthError,
DatabaseError, DatabaseError,
ErrorCode ErrorCode,
) )
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.crud.user import user as user_crud from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionCreate, LogoutRequest from app.schemas.sessions import LogoutRequest, SessionCreate
from app.schemas.users import ( from app.schemas.users import (
LoginRequest,
PasswordResetConfirm,
PasswordResetRequest,
RefreshTokenRequest,
Token,
UserCreate, UserCreate,
UserResponse, UserResponse,
Token,
LoginRequest,
RefreshTokenRequest,
PasswordResetRequest,
PasswordResetConfirm
) )
from app.services.auth_service import AuthService, AuthenticationError from app.services.auth_service import AuthenticationError, AuthService
from app.services.email_service import email_service from app.services.email_service import email_service
from app.utils.device import extract_device_info from app.utils.device import extract_device_info
from app.utils.security import create_password_reset_token, verify_password_reset_token from app.utils.security import create_password_reset_token, verify_password_reset_token
@@ -54,7 +58,7 @@ async def _create_login_session(
request: Request, request: Request,
user: User, user: User,
tokens: Token, tokens: Token,
login_type: str = "login" login_type: str = "login",
) -> None: ) -> None:
""" """
Create a session record for successful login. Create a session record for successful login.
@@ -81,8 +85,8 @@ async def _create_login_session(
device_id=device_info.device_id, device_id=device_info.device_id,
ip_address=device_info.ip_address, ip_address=device_info.ip_address,
user_agent=device_info.user_agent, user_agent=device_info.user_agent,
last_used_at=datetime.now(timezone.utc), last_used_at=datetime.now(UTC),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc), expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
location_city=device_info.location_city, location_city=device_info.location_city,
location_country=device_info.location_country, location_country=device_info.location_country,
) )
@@ -95,15 +99,20 @@ async def _create_login_session(
) )
except Exception as session_err: except Exception as session_err:
# Log but don't fail login if session creation fails # Log but don't fail login if session creation fails
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True) logger.error(
f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register") @router.post(
"/register",
response_model=UserResponse,
status_code=status.HTTP_201_CREATED,
operation_id="register",
)
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute") @limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
async def register_user( async def register_user(
request: Request, request: Request, user_data: UserCreate, db: AsyncSession = Depends(get_db)
user_data: UserCreate,
db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Register a new user. Register a new user.
@@ -116,25 +125,23 @@ async def register_user(
return user return user
except AuthenticationError as e: except AuthenticationError as e:
# SECURITY: Don't reveal if email exists - generic error message # SECURITY: Don't reveal if email exists - generic error message
logger.warning(f"Registration failed: {str(e)}") logger.warning(f"Registration failed: {e!s}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Registration failed. Please check your information and try again." detail="Registration failed. Please check your information and try again.",
) )
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True) logger.error(f"Unexpected error during registration: {e!s}", exc_info=True)
raise DatabaseError( raise DatabaseError(
message="An unexpected error occurred. Please try again later.", message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR error_code=ErrorCode.INTERNAL_ERROR,
) )
@router.post("/login", response_model=Token, operation_id="login") @router.post("/login", response_model=Token, operation_id="login")
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute") @limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def login( async def login(
request: Request, request: Request, login_data: LoginRequest, db: AsyncSession = Depends(get_db)
login_data: LoginRequest,
db: AsyncSession = Depends(get_db)
) -> Any: ) -> Any:
""" """
Login with username and password. Login with username and password.
@@ -146,14 +153,16 @@ async def login(
""" """
try: try:
# Attempt to authenticate the user # Attempt to authenticate the user
user = await AuthService.authenticate_user(db, login_data.email, login_data.password) user = await AuthService.authenticate_user(
db, login_data.email, login_data.password
)
# Explicitly check for None result and raise correct exception # Explicitly check for None result and raise correct exception
if user is None: if user is None:
logger.warning(f"Invalid login attempt for: {login_data.email}") logger.warning(f"Invalid login attempt for: {login_data.email}")
raise AuthError( raise AuthError(
message="Invalid email or password", message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS error_code=ErrorCode.INVALID_CREDENTIALS,
) )
# User is authenticated, generate tokens # User is authenticated, generate tokens
@@ -166,29 +175,26 @@ async def login(
except AuthenticationError as e: except AuthenticationError as e:
# Handle specific authentication errors like inactive accounts # Handle specific authentication errors like inactive accounts
logger.warning(f"Authentication failed: {str(e)}") logger.warning(f"Authentication failed: {e!s}")
raise AuthError( raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
)
except AuthError: except AuthError:
# Re-raise custom auth exceptions without modification # Re-raise custom auth exceptions without modification
raise raise
except Exception as e: except Exception as e:
# Handle unexpected errors # Handle unexpected errors
logger.error(f"Unexpected error during login: {str(e)}", exc_info=True) logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
raise DatabaseError( raise DatabaseError(
message="An unexpected error occurred. Please try again later.", message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR error_code=ErrorCode.INTERNAL_ERROR,
) )
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth') @router.post("/login/oauth", response_model=Token, operation_id="login_oauth")
@limiter.limit("10/minute") @limiter.limit("10/minute")
async def login_oauth( async def login_oauth(
request: Request, request: Request,
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
OAuth2-compatible login endpoint, used by the OpenAPI UI. OAuth2-compatible login endpoint, used by the OpenAPI UI.
@@ -199,12 +205,14 @@ async def login_oauth(
Access and refresh tokens. Access and refresh tokens.
""" """
try: try:
user = await AuthService.authenticate_user(db, form_data.username, form_data.password) user = await AuthService.authenticate_user(
db, form_data.username, form_data.password
)
if user is None: if user is None:
raise AuthError( raise AuthError(
message="Invalid email or password", message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS error_code=ErrorCode.INVALID_CREDENTIALS,
) )
# Generate tokens # Generate tokens
@@ -216,28 +224,25 @@ async def login_oauth(
# Return full token response with user data # Return full token response with user data
return tokens return tokens
except AuthenticationError as e: except AuthenticationError as e:
logger.warning(f"OAuth authentication failed: {str(e)}") logger.warning(f"OAuth authentication failed: {e!s}")
raise AuthError( raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
)
except AuthError: except AuthError:
# Re-raise custom auth exceptions without modification # Re-raise custom auth exceptions without modification
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True) logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
raise DatabaseError( raise DatabaseError(
message="An unexpected error occurred. Please try again later.", message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR error_code=ErrorCode.INTERNAL_ERROR,
) )
@router.post("/refresh", response_model=Token, operation_id="refresh_token") @router.post("/refresh", response_model=Token, operation_id="refresh_token")
@limiter.limit("30/minute") @limiter.limit("30/minute")
async def refresh_token( async def refresh_token(
request: Request, request: Request,
refresh_data: RefreshTokenRequest, refresh_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Refresh access token using a refresh token. Refresh access token using a refresh token.
@@ -249,13 +254,17 @@ async def refresh_token(
""" """
try: try:
# Decode the refresh token to get the JTI # Decode the refresh token to get the JTI
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh") refresh_payload = decode_token(
refresh_data.refresh_token, verify_type="refresh"
)
# Check if session exists and is active # Check if session exists and is active
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti) session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
if not session: if not session:
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}") logger.warning(
f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}"
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Session has been revoked. Please log in again.", detail="Session has been revoked. Please log in again.",
@@ -274,10 +283,12 @@ async def refresh_token(
db, db,
session=session, session=session,
new_jti=new_refresh_payload.jti, new_jti=new_refresh_payload.jti,
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc) new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
) )
except Exception as session_err: except Exception as session_err:
logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True) logger.error(
f"Failed to update session {session.id}: {session_err!s}", exc_info=True
)
# Continue anyway - tokens are already issued # Continue anyway - tokens are already issued
return tokens return tokens
@@ -300,10 +311,10 @@ async def refresh_token(
# Re-raise HTTP exceptions (like session revoked) # Re-raise HTTP exceptions (like session revoked)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during token refresh: {str(e)}") logger.error(f"Unexpected error during token refresh: {e!s}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later." detail="An unexpected error occurred. Please try again later.",
) )
@@ -320,13 +331,13 @@ async def refresh_token(
**Rate Limit**: 3 requests/minute **Rate Limit**: 3 requests/minute
""", """,
operation_id="request_password_reset" operation_id="request_password_reset",
) )
@limiter.limit("3/minute") @limiter.limit("3/minute")
async def request_password_reset( async def request_password_reset(
request: Request, request: Request,
reset_request: PasswordResetRequest, reset_request: PasswordResetRequest,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Request a password reset. Request a password reset.
@@ -345,26 +356,26 @@ async def request_password_reset(
# Send password reset email # Send password reset email
await email_service.send_password_reset_email( await email_service.send_password_reset_email(
to_email=user.email, to_email=user.email, reset_token=reset_token, user_name=user.first_name
reset_token=reset_token,
user_name=user.first_name
) )
logger.info(f"Password reset requested for {user.email}") logger.info(f"Password reset requested for {user.email}")
else: else:
# Log attempt but don't reveal if email exists # Log attempt but don't reveal if email exists
logger.warning(f"Password reset requested for non-existent or inactive email: {reset_request.email}") logger.warning(
f"Password reset requested for non-existent or inactive email: {reset_request.email}"
)
# Always return success to prevent email enumeration # Always return success to prevent email enumeration
return MessageResponse( return MessageResponse(
success=True, success=True,
message="If your email is registered, you will receive a password reset link shortly" message="If your email is registered, you will receive a password reset link shortly",
) )
except Exception as e: except Exception as e:
logger.error(f"Error processing password reset request: {str(e)}", exc_info=True) logger.error(f"Error processing password reset request: {e!s}", exc_info=True)
# Still return success to prevent information leakage # Still return success to prevent information leakage
return MessageResponse( return MessageResponse(
success=True, success=True,
message="If your email is registered, you will receive a password reset link shortly" message="If your email is registered, you will receive a password reset link shortly",
) )
@@ -378,13 +389,13 @@ async def request_password_reset(
**Rate Limit**: 5 requests/minute **Rate Limit**: 5 requests/minute
""", """,
operation_id="confirm_password_reset" operation_id="confirm_password_reset",
) )
@limiter.limit("5/minute") @limiter.limit("5/minute")
async def confirm_password_reset( async def confirm_password_reset(
request: Request, request: Request,
reset_confirm: PasswordResetConfirm, reset_confirm: PasswordResetConfirm,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Confirm password reset with token. Confirm password reset with token.
@@ -398,7 +409,7 @@ async def confirm_password_reset(
if not email: if not email:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired password reset token" detail="Invalid or expired password reset token",
) )
# Look up user # Look up user
@@ -406,14 +417,13 @@ async def confirm_password_reset(
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
detail="User not found"
) )
if not user.is_active: if not user.is_active:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="User account is inactive" detail="User account is inactive",
) )
# Update password # Update password
@@ -424,29 +434,33 @@ async def confirm_password_reset(
# SECURITY: Invalidate all existing sessions after password reset # SECURITY: Invalidate all existing sessions after password reset
# This prevents stolen sessions from being used after password change # This prevents stolen sessions from being used after password change
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
try: try:
deactivated_count = await session_crud.deactivate_all_user_sessions( deactivated_count = await session_crud.deactivate_all_user_sessions(
db, db, user_id=str(user.id)
user_id=str(user.id) )
logger.info(
f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions"
) )
logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions")
except Exception as session_error: except Exception as session_error:
# Log but don't fail password reset if session invalidation fails # Log but don't fail password reset if session invalidation fails
logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}") logger.error(
f"Failed to invalidate sessions after password reset: {session_error!s}"
)
return MessageResponse( return MessageResponse(
success=True, success=True,
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password." message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password.",
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True) logger.error(f"Error confirming password reset: {e!s}", exc_info=True)
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while resetting your password" detail="An error occurred while resetting your password",
) )
@@ -464,14 +478,14 @@ async def confirm_password_reset(
**Rate Limit**: 10 requests/minute **Rate Limit**: 10 requests/minute
""", """,
operation_id="logout" operation_id="logout",
) )
@limiter.limit("10/minute") @limiter.limit("10/minute")
async def logout( async def logout(
request: Request, request: Request,
logout_request: LogoutRequest, logout_request: LogoutRequest,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Logout from current device by deactivating the session. Logout from current device by deactivating the session.
@@ -487,15 +501,14 @@ async def logout(
try: try:
# Decode refresh token to get JTI # Decode refresh token to get JTI
try: try:
refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh") refresh_payload = decode_token(
logout_request.refresh_token, verify_type="refresh"
)
except (TokenExpiredError, TokenInvalidError) as e: except (TokenExpiredError, TokenInvalidError) as e:
# Even if token is expired/invalid, try to deactivate session # Even if token is expired/invalid, try to deactivate session
logger.warning(f"Logout with invalid/expired token: {str(e)}") logger.warning(f"Logout with invalid/expired token: {e!s}")
# Don't fail - return success anyway # Don't fail - return success anyway
return MessageResponse( return MessageResponse(success=True, message="Logged out successfully")
success=True,
message="Logged out successfully"
)
# Find the session by JTI # Find the session by JTI
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti) session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
@@ -509,7 +522,7 @@ async def logout(
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="You can only logout your own sessions" detail="You can only logout your own sessions",
) )
# Deactivate the session # Deactivate the session
@@ -522,22 +535,20 @@ async def logout(
else: else:
# Session not found - maybe already deleted or never existed # Session not found - maybe already deleted or never existed
# Return success anyway (idempotent) # Return success anyway (idempotent)
logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})") logger.info(
f"Logout requested for non-existent session (JTI: {refresh_payload.jti})"
)
return MessageResponse( return MessageResponse(success=True, message="Logged out successfully")
success=True,
message="Logged out successfully"
)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True) logger.error(
# Don't expose error details f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
return MessageResponse(
success=True,
message="Logged out successfully"
) )
# Don't expose error details
return MessageResponse(success=True, message="Logged out successfully")
@router.post( @router.post(
@@ -553,13 +564,13 @@ async def logout(
**Rate Limit**: 5 requests/minute **Rate Limit**: 5 requests/minute
""", """,
operation_id="logout_all" operation_id="logout_all",
) )
@limiter.limit("5/minute") @limiter.limit("5/minute")
async def logout_all( async def logout_all(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Logout from all devices by deactivating all user sessions. Logout from all devices by deactivating all user sessions.
@@ -573,19 +584,25 @@ async def logout_all(
""" """
try: try:
# Deactivate all sessions for this user # Deactivate all sessions for this user
count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id)) count = await session_crud.deactivate_all_user_sessions(
db, user_id=str(current_user.id)
)
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)") logger.info(
f"User {current_user.id} logged out from all devices ({count} sessions)"
)
return MessageResponse( return MessageResponse(
success=True, success=True,
message=f"Successfully logged out from all devices ({count} sessions terminated)" message=f"Successfully logged out from all devices ({count} sessions terminated)",
) )
except Exception as e: except Exception as e:
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True) logger.error(
f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
)
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while logging out" detail="An error occurred while logging out",
) )

View File

@@ -4,8 +4,9 @@ Organization endpoints for regular users.
These endpoints allow users to view and manage organizations they belong to. These endpoints allow users to view and manage organizations they belong to.
""" """
import logging import logging
from typing import Any, List from typing import Any
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
@@ -14,18 +15,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership from app.api.dependencies.permissions import require_org_admin, require_org_membership
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import NotFoundError, ErrorCode from app.core.exceptions import ErrorCode, NotFoundError
from app.crud.organization import organization as organization_crud from app.crud.organization import organization as organization_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import ( from app.schemas.common import (
PaginationParams,
PaginatedResponse, PaginatedResponse,
create_pagination_meta PaginationParams,
create_pagination_meta,
) )
from app.schemas.organizations import ( from app.schemas.organizations import (
OrganizationResponse,
OrganizationMemberResponse, OrganizationMemberResponse,
OrganizationUpdate OrganizationResponse,
OrganizationUpdate,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,15 +36,15 @@ router = APIRouter()
@router.get( @router.get(
"/me", "/me",
response_model=List[OrganizationResponse], response_model=list[OrganizationResponse],
summary="Get My Organizations", summary="Get My Organizations",
description="Get all organizations the current user belongs to", description="Get all organizations the current user belongs to",
operation_id="get_my_organizations" operation_id="get_my_organizations",
) )
async def get_my_organizations( async def get_my_organizations(
is_active: bool = Query(True, description="Filter by active membership"), is_active: bool = Query(True, description="Filter by active membership"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Get all organizations the current user belongs to. Get all organizations the current user belongs to.
@@ -54,15 +55,13 @@ async def get_my_organizations(
try: try:
# Get all org data in single query with JOIN and subquery # Get all org data in single query with JOIN and subquery
orgs_data = await organization_crud.get_user_organizations_with_details( orgs_data = await organization_crud.get_user_organizations_with_details(
db, db, user_id=current_user.id, is_active=is_active
user_id=current_user.id,
is_active=is_active
) )
# Transform to response objects # Transform to response objects
orgs_with_data = [] orgs_with_data = []
for item in orgs_data: for item in orgs_data:
org = item['organization'] org = item["organization"]
org_dict = { org_dict = {
"id": org.id, "id": org.id,
"name": org.name, "name": org.name,
@@ -72,14 +71,14 @@ async def get_my_organizations(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": item['member_count'] "member_count": item["member_count"],
} }
orgs_with_data.append(OrganizationResponse(**org_dict)) orgs_with_data.append(OrganizationResponse(**org_dict))
return orgs_with_data return orgs_with_data
except Exception as e: except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}", exc_info=True) logger.error(f"Error getting user organizations: {e!s}", exc_info=True)
raise raise
@@ -88,12 +87,12 @@ async def get_my_organizations(
response_model=OrganizationResponse, response_model=OrganizationResponse,
summary="Get Organization Details", summary="Get Organization Details",
description="Get details of an organization the user belongs to", description="Get details of an organization the user belongs to",
operation_id="get_organization" operation_id="get_organization",
) )
async def get_organization( async def get_organization(
organization_id: UUID, organization_id: UUID,
current_user: User = Depends(require_org_membership), current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Get details of a specific organization. Get details of a specific organization.
@@ -105,7 +104,7 @@ async def get_organization(
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md) if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError( raise NotFoundError(
detail=f"Organization {organization_id} not found", detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND,
) )
org_dict = { org_dict = {
@@ -117,14 +116,16 @@ async def get_organization(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=org.id) "member_count": await organization_crud.get_member_count(
db, organization_id=org.id
),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above except NotFoundError: # pragma: no cover - See above
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error getting organization: {str(e)}", exc_info=True) logger.error(f"Error getting organization: {e!s}", exc_info=True)
raise raise
@@ -133,14 +134,14 @@ async def get_organization(
response_model=PaginatedResponse[OrganizationMemberResponse], response_model=PaginatedResponse[OrganizationMemberResponse],
summary="Get Organization Members", summary="Get Organization Members",
description="Get all members of an organization (members can view)", description="Get all members of an organization (members can view)",
operation_id="get_organization_members" operation_id="get_organization_members",
) )
async def get_organization_members( async def get_organization_members(
organization_id: UUID, organization_id: UUID,
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"), is_active: bool = Query(True, description="Filter by active status"),
current_user: User = Depends(require_org_membership), current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Get all members of an organization. Get all members of an organization.
@@ -153,7 +154,7 @@ async def get_organization_members(
organization_id=organization_id, organization_id=organization_id,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
is_active=is_active is_active=is_active,
) )
member_responses = [OrganizationMemberResponse(**member) for member in members] member_responses = [OrganizationMemberResponse(**member) for member in members]
@@ -162,13 +163,13 @@ async def get_organization_members(
total=total, total=total,
page=pagination.page, page=pagination.page,
limit=pagination.limit, limit=pagination.limit,
items_count=len(member_responses) items_count=len(member_responses),
) )
return PaginatedResponse(data=member_responses, pagination=pagination_meta) return PaginatedResponse(data=member_responses, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error getting organization members: {str(e)}", exc_info=True) logger.error(f"Error getting organization members: {e!s}", exc_info=True)
raise raise
@@ -177,13 +178,13 @@ async def get_organization_members(
response_model=OrganizationResponse, response_model=OrganizationResponse,
summary="Update Organization", summary="Update Organization",
description="Update organization details (admin/owner only)", description="Update organization details (admin/owner only)",
operation_id="update_organization" operation_id="update_organization",
) )
async def update_organization( async def update_organization(
organization_id: UUID, organization_id: UUID,
org_in: OrganizationUpdate, org_in: OrganizationUpdate,
current_user: User = Depends(require_org_admin), current_user: User = Depends(require_org_admin),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Update organization details. Update organization details.
@@ -195,11 +196,13 @@ async def update_organization(
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md) if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError( raise NotFoundError(
detail=f"Organization {organization_id} not found", detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND,
) )
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in) updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info(f"User {current_user.email} updated organization {updated_org.name}") logger.info(
f"User {current_user.email} updated organization {updated_org.name}"
)
org_dict = { org_dict = {
"id": updated_org.id, "id": updated_org.id,
@@ -210,12 +213,14 @@ async def update_organization(
"settings": updated_org.settings, "settings": updated_org.settings,
"created_at": updated_org.created_at, "created_at": updated_org.created_at,
"updated_at": updated_org.updated_at, "updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id) "member_count": await organization_crud.get_member_count(
db, organization_id=updated_org.id
),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above except NotFoundError: # pragma: no cover - See above
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error updating organization: {str(e)}", exc_info=True) logger.error(f"Error updating organization: {e!s}", exc_info=True)
raise raise

View File

@@ -3,11 +3,12 @@ Session management endpoints.
Allows users to view and manage their active sessions across devices. Allows users to view and manage their active sessions across devices.
""" """
import logging import logging
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, Request, status
from slowapi import Limiter from slowapi import Limiter
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -15,11 +16,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.core.auth import decode_token from app.core.auth import decode_token
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionResponse, SessionListResponse from app.schemas.sessions import SessionListResponse, SessionResponse
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -39,13 +40,13 @@ limiter = Limiter(key_func=get_remote_address)
**Rate Limit**: 30 requests/minute **Rate Limit**: 30 requests/minute
""", """,
operation_id="list_my_sessions" operation_id="list_my_sessions",
) )
@limiter.limit("30/minute") @limiter.limit("30/minute")
async def list_my_sessions( async def list_my_sessions(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
List all active sessions for the current user. List all active sessions for the current user.
@@ -60,18 +61,15 @@ async def list_my_sessions(
try: try:
# Get all active sessions for user # Get all active sessions for user
sessions = await session_crud.get_user_sessions( sessions = await session_crud.get_user_sessions(
db, db, user_id=str(current_user.id), active_only=True
user_id=str(current_user.id),
active_only=True
) )
# Try to identify current session from Authorization header # Try to identify current session from Authorization header
current_session_jti = None
auth_header = request.headers.get("authorization") auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "): if auth_header and auth_header.startswith("Bearer "):
try: try:
access_token = auth_header.split(" ")[1] access_token = auth_header.split(" ")[1]
token_payload = decode_token(access_token) decode_token(access_token)
# Note: Access tokens don't have JTI by default, but we can try # Note: Access tokens don't have JTI by default, but we can try
# For now, we'll mark current based on most recent activity # For now, we'll mark current based on most recent activity
except Exception: except Exception:
@@ -90,22 +88,27 @@ async def list_my_sessions(
last_used_at=s.last_used_at, last_used_at=s.last_used_at,
created_at=s.created_at, created_at=s.created_at,
expires_at=s.expires_at, expires_at=s.expires_at,
is_current=(s == sessions[0] if sessions else False) # Most recent = current is_current=(
s == sessions[0] if sessions else False
), # Most recent = current
) )
session_responses.append(session_response) session_responses.append(session_response)
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions") logger.info(
f"User {current_user.id} listed {len(session_responses)} active sessions"
)
return SessionListResponse( return SessionListResponse(
sessions=session_responses, sessions=session_responses, total=len(session_responses)
total=len(session_responses)
) )
except Exception as e: except Exception as e:
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True) logger.error(
f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve sessions" detail="Failed to retrieve sessions",
) )
@@ -122,14 +125,14 @@ async def list_my_sessions(
**Rate Limit**: 10 requests/minute **Rate Limit**: 10 requests/minute
""", """,
operation_id="revoke_session" operation_id="revoke_session",
) )
@limiter.limit("10/minute") @limiter.limit("10/minute")
async def revoke_session( async def revoke_session(
request: Request, request: Request,
session_id: UUID, session_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Revoke a specific session by ID. Revoke a specific session by ID.
@@ -149,7 +152,7 @@ async def revoke_session(
if not session: if not session:
raise NotFoundError( raise NotFoundError(
message=f"Session {session_id} not found", message=f"Session {session_id} not found",
error_code=ErrorCode.NOT_FOUND error_code=ErrorCode.NOT_FOUND,
) )
# Verify session belongs to current user # Verify session belongs to current user
@@ -160,7 +163,7 @@ async def revoke_session(
) )
raise AuthorizationError( raise AuthorizationError(
message="You can only revoke your own sessions", message="You can only revoke your own sessions",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
) )
# Deactivate the session # Deactivate the session
@@ -173,16 +176,16 @@ async def revoke_session(
return MessageResponse( return MessageResponse(
success=True, success=True,
message=f"Session revoked: {session.device_name or 'Unknown device'}" message=f"Session revoked: {session.device_name or 'Unknown device'}",
) )
except (NotFoundError, AuthorizationError): except (NotFoundError, AuthorizationError):
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True) logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke session" detail="Failed to revoke session",
) )
@@ -198,13 +201,13 @@ async def revoke_session(
**Rate Limit**: 5 requests/minute **Rate Limit**: 5 requests/minute
""", """,
operation_id="cleanup_expired_sessions" operation_id="cleanup_expired_sessions",
) )
@limiter.limit("5/minute") @limiter.limit("5/minute")
async def cleanup_expired_sessions( async def cleanup_expired_sessions(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Cleanup expired sessions for the current user. Cleanup expired sessions for the current user.
@@ -219,21 +222,24 @@ async def cleanup_expired_sessions(
try: try:
# Use optimized bulk DELETE instead of N individual deletes # Use optimized bulk DELETE instead of N individual deletes
deleted_count = await session_crud.cleanup_expired_for_user( deleted_count = await session_crud.cleanup_expired_for_user(
db, db, user_id=str(current_user.id)
user_id=str(current_user.id)
) )
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions") logger.info(
f"User {current_user.id} cleaned up {deleted_count} expired sessions"
)
return MessageResponse( return MessageResponse(
success=True, success=True, message=f"Cleaned up {deleted_count} expired sessions"
message=f"Cleaned up {deleted_count} expired sessions"
) )
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True) logger.error(
f"Error cleaning up sessions for user {current_user.id}: {e!s}",
exc_info=True,
)
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cleanup sessions" detail="Failed to cleanup sessions",
) )

View File

@@ -1,33 +1,30 @@
""" """
User management endpoints for CRUD operations. User management endpoints for CRUD operations.
""" """
import logging import logging
from typing import Any, Optional from typing import Any
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query, status, Request from fastapi import APIRouter, Depends, Query, Request, status
from slowapi import Limiter from slowapi import Limiter
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user, get_current_superuser from app.api.dependencies.auth import get_current_superuser, get_current_user
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import ( from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
NotFoundError,
AuthorizationError,
ErrorCode
)
from app.crud.user import user as user_crud from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import ( from app.schemas.common import (
PaginationParams,
PaginatedResponse,
MessageResponse, MessageResponse,
PaginatedResponse,
PaginationParams,
SortParams, SortParams,
create_pagination_meta create_pagination_meta,
) )
from app.schemas.users import UserResponse, UserUpdate, PasswordChange from app.schemas.users import PasswordChange, UserResponse, UserUpdate
from app.services.auth_service import AuthService, AuthenticationError from app.services.auth_service import AuthenticationError, AuthService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -50,15 +47,15 @@ limiter = Limiter(key_func=get_remote_address)
**Rate Limit**: 60 requests/minute **Rate Limit**: 60 requests/minute
""", """,
operation_id="list_users" operation_id="list_users",
) )
async def list_users( async def list_users(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
sort: SortParams = Depends(), sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"), is_active: bool | None = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"), is_superuser: bool | None = Query(None, description="Filter by superuser status"),
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
List all users with pagination, filtering, and sorting. List all users with pagination, filtering, and sorting.
@@ -80,7 +77,7 @@ async def list_users(
limit=pagination.limit, limit=pagination.limit,
sort_by=sort.sort_by, sort_by=sort.sort_by,
sort_order=sort.sort_order.value if sort.sort_order else "asc", sort_order=sort.sort_order.value if sort.sort_order else "asc",
filters=filters if filters else None filters=filters if filters else None,
) )
# Create pagination metadata # Create pagination metadata
@@ -88,15 +85,12 @@ async def list_users(
total=total, total=total,
page=pagination.page, page=pagination.page,
limit=pagination.limit, limit=pagination.limit,
items_count=len(users) items_count=len(users),
) )
return PaginatedResponse( return PaginatedResponse(data=users, pagination=pagination_meta)
data=users,
pagination=pagination_meta
)
except Exception as e: except Exception as e:
logger.error(f"Error listing users: {str(e)}", exc_info=True) logger.error(f"Error listing users: {e!s}", exc_info=True)
raise raise
@@ -111,11 +105,9 @@ async def list_users(
**Rate Limit**: 60 requests/minute **Rate Limit**: 60 requests/minute
""", """,
operation_id="get_current_user_profile" operation_id="get_current_user_profile",
) )
def get_current_user_profile( def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
current_user: User = Depends(get_current_user)
) -> Any:
"""Get current user's profile.""" """Get current user's profile."""
return current_user return current_user
@@ -133,12 +125,12 @@ def get_current_user_profile(
**Rate Limit**: 30 requests/minute **Rate Limit**: 30 requests/minute
""", """,
operation_id="update_current_user" operation_id="update_current_user",
) )
async def update_current_user( async def update_current_user(
user_update: UserUpdate, user_update: UserUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Update current user's profile. Update current user's profile.
@@ -147,17 +139,17 @@ async def update_current_user(
""" """
try: try:
updated_user = await user_crud.update( updated_user = await user_crud.update(
db, db, db_obj=current_user, obj_in=user_update
db_obj=current_user,
obj_in=user_update
) )
logger.info(f"User {current_user.id} updated their profile") logger.info(f"User {current_user.id} updated their profile")
return updated_user return updated_user
except ValueError as e: except ValueError as e:
logger.error(f"Error updating user {current_user.id}: {str(e)}") logger.error(f"Error updating user {current_user.id}: {e!s}")
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True) logger.error(
f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
)
raise raise
@@ -175,12 +167,12 @@ async def update_current_user(
**Rate Limit**: 60 requests/minute **Rate Limit**: 60 requests/minute
""", """,
operation_id="get_user_by_id" operation_id="get_user_by_id",
) )
async def get_user_by_id( async def get_user_by_id(
user_id: UUID, user_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Get user by ID. Get user by ID.
@@ -194,7 +186,7 @@ async def get_user_by_id(
) )
raise AuthorizationError( raise AuthorizationError(
message="Not enough permissions to view this user", message="Not enough permissions to view this user",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
) )
# Get user # Get user
@@ -202,7 +194,7 @@ async def get_user_by_id(
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User with id {user_id} not found", message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND error_code=ErrorCode.USER_NOT_FOUND,
) )
return user return user
@@ -222,13 +214,13 @@ async def get_user_by_id(
**Rate Limit**: 30 requests/minute **Rate Limit**: 30 requests/minute
""", """,
operation_id="update_user" operation_id="update_user",
) )
async def update_user( async def update_user(
user_id: UUID, user_id: UUID,
user_update: UserUpdate, user_update: UserUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Update user by ID. Update user by ID.
@@ -245,7 +237,7 @@ async def update_user(
) )
raise AuthorizationError( raise AuthorizationError(
message="Not enough permissions to update this user", message="Not enough permissions to update this user",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
) )
# Get user # Get user
@@ -253,7 +245,7 @@ async def update_user(
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User with id {user_id} not found", message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND error_code=ErrorCode.USER_NOT_FOUND,
) )
try: try:
@@ -261,10 +253,10 @@ async def update_user(
logger.info(f"User {user_id} updated by {current_user.id}") logger.info(f"User {user_id} updated by {current_user.id}")
return updated_user return updated_user
except ValueError as e: except ValueError as e:
logger.error(f"Error updating user {user_id}: {str(e)}") logger.error(f"Error updating user {user_id}: {e!s}")
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True) logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True)
raise raise
@@ -281,14 +273,14 @@ async def update_user(
**Rate Limit**: 5 requests/minute **Rate Limit**: 5 requests/minute
""", """,
operation_id="change_current_user_password" operation_id="change_current_user_password",
) )
@limiter.limit("5/minute") @limiter.limit("5/minute")
async def change_current_user_password( async def change_current_user_password(
request: Request, request: Request,
password_change: PasswordChange, password_change: PasswordChange,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Change current user's password. Change current user's password.
@@ -300,23 +292,23 @@ async def change_current_user_password(
db=db, db=db,
user_id=current_user.id, user_id=current_user.id,
current_password=password_change.current_password, current_password=password_change.current_password,
new_password=password_change.new_password new_password=password_change.new_password,
) )
if success: if success:
logger.info(f"User {current_user.id} changed their password") logger.info(f"User {current_user.id} changed their password")
return MessageResponse( return MessageResponse(
success=True, success=True, message="Password changed successfully"
message="Password changed successfully"
) )
except AuthenticationError as e: except AuthenticationError as e:
logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}") logger.warning(
f"Failed password change attempt for user {current_user.id}: {e!s}"
)
raise AuthorizationError( raise AuthorizationError(
message=str(e), message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
error_code=ErrorCode.INVALID_CREDENTIALS
) )
except Exception as e: except Exception as e:
logger.error(f"Error changing password for user {current_user.id}: {str(e)}") logger.error(f"Error changing password for user {current_user.id}: {e!s}")
raise raise
@@ -335,12 +327,12 @@ async def change_current_user_password(
**Note**: This performs a hard delete. Consider implementing soft deletes for production. **Note**: This performs a hard delete. Consider implementing soft deletes for production.
""", """,
operation_id="delete_user" operation_id="delete_user",
) )
async def delete_user( async def delete_user(
user_id: UUID, user_id: UUID,
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Delete user by ID (superuser only). Delete user by ID (superuser only).
@@ -351,7 +343,7 @@ async def delete_user(
if str(user_id) == str(current_user.id): if str(user_id) == str(current_user.id):
raise AuthorizationError( raise AuthorizationError(
message="Cannot delete your own account", message="Cannot delete your own account",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
) )
# Get user # Get user
@@ -359,7 +351,7 @@ async def delete_user(
if not user: if not user:
raise NotFoundError( raise NotFoundError(
message=f"User with id {user_id} not found", message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND error_code=ErrorCode.USER_NOT_FOUND,
) )
try: try:
@@ -367,12 +359,11 @@ async def delete_user(
await user_crud.soft_delete(db, id=str(user_id)) await user_crud.soft_delete(db, id=str(user_id))
logger.info(f"User {user_id} soft-deleted by {current_user.id}") logger.info(f"User {user_id} soft-deleted by {current_user.id}")
return MessageResponse( return MessageResponse(
success=True, success=True, message=f"User {user_id} deleted successfully"
message=f"User {user_id} deleted successfully"
) )
except ValueError as e: except ValueError as e:
logger.error(f"Error deleting user {user_id}: {str(e)}") logger.error(f"Error deleting user {user_id}: {e!s}")
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True) logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True)
raise raise

View File

@@ -1,39 +1,39 @@
import logging import logging
logging.getLogger('passlib').setLevel(logging.ERROR)
from datetime import datetime, timedelta, timezone logging.getLogger("passlib").setLevel(logging.ERROR)
from typing import Any, Dict, Optional, Union
import uuid
import asyncio import asyncio
import uuid
from datetime import UTC, datetime, timedelta
from functools import partial from functools import partial
from typing import Any
from jose import jwt, JWTError from jose import JWTError, jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import ValidationError from pydantic import ValidationError
from app.core.config import settings from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload from app.schemas.users import TokenData, TokenPayload
# Password hashing context # Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth # Custom exceptions for auth
class AuthError(Exception): class AuthError(Exception):
"""Base authentication error""" """Base authentication error"""
pass
class TokenExpiredError(AuthError): class TokenExpiredError(AuthError):
"""Token has expired""" """Token has expired"""
pass
class TokenInvalidError(AuthError): class TokenInvalidError(AuthError):
"""Token is invalid""" """Token is invalid"""
pass
class TokenMissingClaimError(AuthError): class TokenMissingClaimError(AuthError):
"""Token is missing a required claim""" """Token is missing a required claim"""
pass
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
@@ -62,8 +62,7 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.run_in_executor( return await loop.run_in_executor(
None, None, partial(pwd_context.verify, plain_password, hashed_password)
partial(pwd_context.verify, plain_password, hashed_password)
) )
@@ -82,17 +81,13 @@ async def get_password_hash_async(password: str) -> str:
Hashed password string Hashed password string
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.run_in_executor( return await loop.run_in_executor(None, pwd_context.hash, password)
None,
pwd_context.hash,
password
)
def create_access_token( def create_access_token(
subject: Union[str, Any], subject: str | Any,
expires_delta: Optional[timedelta] = None, expires_delta: timedelta | None = None,
claims: Optional[Dict[str, Any]] = None claims: dict[str, Any] | None = None,
) -> str: ) -> str:
""" """
Create a JWT access token. Create a JWT access token.
@@ -106,17 +101,19 @@ def create_access_token(
Encoded JWT token Encoded JWT token
""" """
if expires_delta: if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta expire = datetime.now(UTC) + expires_delta
else: else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.now(UTC) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
# Base token data # Base token data
to_encode = { to_encode = {
"sub": str(subject), "sub": str(subject),
"exp": expire, "exp": expire,
"iat": datetime.now(tz=timezone.utc), "iat": datetime.now(tz=UTC),
"jti": str(uuid.uuid4()), "jti": str(uuid.uuid4()),
"type": "access" "type": "access",
} }
# Add custom claims # Add custom claims
@@ -125,17 +122,14 @@ def create_access_token(
# Create the JWT # Create the JWT
encoded_jwt = jwt.encode( encoded_jwt = jwt.encode(
to_encode, to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
) )
return encoded_jwt return encoded_jwt
def create_refresh_token( def create_refresh_token(
subject: Union[str, Any], subject: str | Any, expires_delta: timedelta | None = None
expires_delta: Optional[timedelta] = None
) -> str: ) -> str:
""" """
Create a JWT refresh token. Create a JWT refresh token.
@@ -148,28 +142,26 @@ def create_refresh_token(
Encoded JWT refresh token Encoded JWT refresh token
""" """
if expires_delta: if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta expire = datetime.now(UTC) + expires_delta
else: else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) expire = datetime.now(UTC) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = { to_encode = {
"sub": str(subject), "sub": str(subject),
"exp": expire, "exp": expire,
"iat": datetime.now(timezone.utc), "iat": datetime.now(UTC),
"jti": str(uuid.uuid4()), "jti": str(uuid.uuid4()),
"type": "refresh" "type": "refresh",
} }
encoded_jwt = jwt.encode( encoded_jwt = jwt.encode(
to_encode, to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
) )
return encoded_jwt return encoded_jwt
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload: def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
""" """
Decode and verify a JWT token. Decode and verify a JWT token.
@@ -195,8 +187,8 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
"verify_signature": True, "verify_signature": True,
"verify_exp": True, "verify_exp": True,
"verify_iat": True, "verify_iat": True,
"require": ["exp", "sub", "iat"] "require": ["exp", "sub", "iat"],
} },
) )
# SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks # SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks

View File

@@ -1,5 +1,4 @@
import logging import logging
from typing import Optional, List
from pydantic import Field, field_validator from pydantic import Field, field_validator
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@@ -13,7 +12,7 @@ class Settings(BaseSettings):
# Environment (must be before SECRET_KEY for validation) # Environment (must be before SECRET_KEY for validation)
ENVIRONMENT: str = Field( ENVIRONMENT: str = Field(
default="development", default="development",
description="Environment: development, staging, or production" description="Environment: development, staging, or production",
) )
# Security: Content Security Policy # Security: Content Security Policy
@@ -21,8 +20,7 @@ class Settings(BaseSettings):
# Set to True for strict CSP (blocks most external resources) # Set to True for strict CSP (blocks most external resources)
# Set to "relaxed" for modern frontend development # Set to "relaxed" for modern frontend development
CSP_MODE: str = Field( CSP_MODE: str = Field(
default="relaxed", default="relaxed", description="CSP mode: 'strict', 'relaxed', or 'disabled'"
description="CSP mode: 'strict', 'relaxed', or 'disabled'"
) )
# Database configuration # Database configuration
@@ -31,7 +29,7 @@ class Settings(BaseSettings):
POSTGRES_HOST: str = "localhost" POSTGRES_HOST: str = "localhost"
POSTGRES_PORT: str = "5432" POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "app" POSTGRES_DB: str = "app"
DATABASE_URL: Optional[str] = None DATABASE_URL: str | None = None
db_pool_size: int = 20 # Default connection pool size db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection db_pool_timeout: int = 30 # Seconds to wait for a connection
@@ -59,38 +57,36 @@ class Settings(BaseSettings):
SECRET_KEY: str = Field( SECRET_KEY: str = Field(
default="dev_only_insecure_key_change_in_production_32chars_min", default="dev_only_insecure_key_change_in_production_32chars_min",
min_length=32, min_length=32,
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'" description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'",
) )
ALGORITHM: str = "HS256" ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard) ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard)
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days
# CORS configuration # CORS configuration
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"] BACKEND_CORS_ORIGINS: list[str] = ["http://localhost:3000"]
# Frontend URL for email links # Frontend URL for email links
FRONTEND_URL: str = Field( FRONTEND_URL: str = Field(
default="http://localhost:3000", default="http://localhost:3000",
description="Frontend application URL for email links" description="Frontend application URL for email links",
) )
# Admin user # Admin user
FIRST_SUPERUSER_EMAIL: Optional[str] = Field( FIRST_SUPERUSER_EMAIL: str | None = Field(
default=None, default=None, description="Email for first superuser account"
description="Email for first superuser account"
) )
FIRST_SUPERUSER_PASSWORD: Optional[str] = Field( FIRST_SUPERUSER_PASSWORD: str | None = Field(
default=None, default=None, description="Password for first superuser (min 12 characters)"
description="Password for first superuser (min 12 characters)"
) )
@field_validator('SECRET_KEY') @field_validator("SECRET_KEY")
@classmethod @classmethod
def validate_secret_key(cls, v: str, info) -> str: def validate_secret_key(cls, v: str, info) -> str:
"""Validate SECRET_KEY is secure, especially in production.""" """Validate SECRET_KEY is secure, especially in production."""
# Get environment from values if available # Get environment from values if available
values_data = info.data if info.data else {} values_data = info.data if info.data else {}
env = values_data.get('ENVIRONMENT', 'development') env = values_data.get("ENVIRONMENT", "development")
if v.startswith("your_secret_key_here"): if v.startswith("your_secret_key_here"):
if env == "production": if env == "production":
@@ -106,13 +102,15 @@ class Settings(BaseSettings):
) )
if len(v) < 32: if len(v) < 32:
raise ValueError("SECRET_KEY must be at least 32 characters long for security") raise ValueError(
"SECRET_KEY must be at least 32 characters long for security"
)
return v return v
@field_validator('FIRST_SUPERUSER_PASSWORD') @field_validator("FIRST_SUPERUSER_PASSWORD")
@classmethod @classmethod
def validate_superuser_password(cls, v: Optional[str]) -> Optional[str]: def validate_superuser_password(cls, v: str | None) -> str | None:
"""Validate superuser password strength.""" """Validate superuser password strength."""
if v is None: if v is None:
return v return v
@@ -121,7 +119,13 @@ class Settings(BaseSettings):
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters") raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
# Check for common weak passwords # Check for common weak passwords
weak_passwords = {'admin123', 'Admin123', 'password123', 'Password123', '123456789012'} weak_passwords = {
"admin123",
"Admin123",
"password123",
"Password123",
"123456789012",
}
if v in weak_passwords: if v in weak_passwords:
raise ValueError( raise ValueError(
"FIRST_SUPERUSER_PASSWORD is too weak. " "FIRST_SUPERUSER_PASSWORD is too weak. "
@@ -144,7 +148,7 @@ class Settings(BaseSettings):
"env_file": "../.env", "env_file": "../.env",
"env_file_encoding": "utf-8", "env_file_encoding": "utf-8",
"case_sensitive": True, "case_sensitive": True,
"extra": "ignore" # Ignore extra fields from .env (e.g., frontend-specific vars) "extra": "ignore", # Ignore extra fields from .env (e.g., frontend-specific vars)
} }

View File

@@ -5,17 +5,18 @@ Database configuration using SQLAlchemy 2.0 and asyncpg.
This module provides async database connectivity with proper connection pooling This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints. and session management for FastAPI endpoints.
""" """
import logging import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine, AsyncEngine,
create_async_engine, AsyncSession,
async_sessionmaker, async_sessionmaker,
create_async_engine,
) )
from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
@@ -27,12 +28,12 @@ logger = logging.getLogger(__name__)
# SQLite compatibility for testing # SQLite compatibility for testing
@compiles(JSONB, 'sqlite') @compiles(JSONB, "sqlite")
def compile_jsonb_sqlite(type_, compiler, **kw): def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT" return "TEXT"
@compiles(UUID, 'sqlite') @compiles(UUID, "sqlite")
def compile_uuid_sqlite(type_, compiler, **kw): def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT" return "TEXT"
@@ -40,7 +41,6 @@ def compile_uuid_sqlite(type_, compiler, **kw):
# Declarative base for models (SQLAlchemy 2.0 style) # Declarative base for models (SQLAlchemy 2.0 style)
class Base(DeclarativeBase): class Base(DeclarativeBase):
"""Base class for all database models.""" """Base class for all database models."""
pass
def get_async_database_url(url: str) -> str: def get_async_database_url(url: str) -> str:
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
logger.debug("Async transaction committed successfully") logger.debug("Async transaction committed successfully")
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
logger.error(f"Async transaction failed, rolling back: {str(e)}") logger.error(f"Async transaction failed, rolling back: {e!s}")
raise raise
finally: finally:
await session.close() await session.close()
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
await db.execute(text("SELECT 1")) await db.execute(text("SELECT 1"))
return True return True
except Exception as e: except Exception as e:
logger.error(f"Async database health check failed: {str(e)}") logger.error(f"Async database health check failed: {e!s}")
return False return False

View File

@@ -1,8 +1,8 @@
""" """
Custom exceptions and global exception handlers for the API. Custom exceptions and global exception handlers for the API.
""" """
import logging import logging
from typing import Optional, Union
from fastapi import HTTPException, Request, status from fastapi import HTTPException, Request, status
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
@@ -27,17 +27,13 @@ class APIException(HTTPException):
status_code: int, status_code: int,
error_code: ErrorCode, error_code: ErrorCode,
message: str, message: str,
field: Optional[str] = None, field: str | None = None,
headers: Optional[dict] = None headers: dict | None = None,
): ):
self.error_code = error_code self.error_code = error_code
self.field = field self.field = field
self.message = message self.message = message
super().__init__( super().__init__(status_code=status_code, detail=message, headers=headers)
status_code=status_code,
detail=message,
headers=headers
)
class AuthenticationError(APIException): class AuthenticationError(APIException):
@@ -47,14 +43,14 @@ class AuthenticationError(APIException):
self, self,
message: str = "Authentication failed", message: str = "Authentication failed",
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS, error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
field: Optional[str] = None field: str | None = None,
): ):
super().__init__( super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
error_code=error_code, error_code=error_code,
message=message, message=message,
field=field, field=field,
headers={"WWW-Authenticate": "Bearer"} headers={"WWW-Authenticate": "Bearer"},
) )
@@ -64,12 +60,12 @@ class AuthorizationError(APIException):
def __init__( def __init__(
self, self,
message: str = "Insufficient permissions", message: str = "Insufficient permissions",
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS,
): ):
super().__init__( super().__init__(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
error_code=error_code, error_code=error_code,
message=message message=message,
) )
@@ -79,12 +75,12 @@ class NotFoundError(APIException):
def __init__( def __init__(
self, self,
message: str = "Resource not found", message: str = "Resource not found",
error_code: ErrorCode = ErrorCode.NOT_FOUND error_code: ErrorCode = ErrorCode.NOT_FOUND,
): ):
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
error_code=error_code, error_code=error_code,
message=message message=message,
) )
@@ -95,13 +91,13 @@ class DuplicateError(APIException):
self, self,
message: str = "Resource already exists", message: str = "Resource already exists",
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY, error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
field: Optional[str] = None field: str | None = None,
): ):
super().__init__( super().__init__(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
error_code=error_code, error_code=error_code,
message=message, message=message,
field=field field=field,
) )
@@ -112,13 +108,13 @@ class ValidationException(APIException):
self, self,
message: str = "Validation error", message: str = "Validation error",
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR, error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
field: Optional[str] = None field: str | None = None,
): ):
super().__init__( super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
error_code=error_code, error_code=error_code,
message=message, message=message,
field=field field=field,
) )
@@ -128,12 +124,12 @@ class DatabaseError(APIException):
def __init__( def __init__(
self, self,
message: str = "Database operation failed", message: str = "Database operation failed",
error_code: ErrorCode = ErrorCode.DATABASE_ERROR error_code: ErrorCode = ErrorCode.DATABASE_ERROR,
): ):
super().__init__( super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
error_code=error_code, error_code=error_code,
message=message message=message,
) )
@@ -152,23 +148,18 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
) )
error_response = ErrorResponse( error_response = ErrorResponse(
errors=[ErrorDetail( errors=[ErrorDetail(code=exc.error_code, message=exc.message, field=exc.field)]
code=exc.error_code,
message=exc.message,
field=exc.field
)]
) )
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content=error_response.model_dump(), content=error_response.model_dump(),
headers=exc.headers headers=exc.headers,
) )
async def validation_exception_handler( async def validation_exception_handler(
request: Request, request: Request, exc: RequestValidationError | ValidationError
exc: Union[RequestValidationError, ValidationError]
) -> JSONResponse: ) -> JSONResponse:
""" """
Handler for Pydantic validation errors. Handler for Pydantic validation errors.
@@ -189,22 +180,19 @@ async def validation_exception_handler(
# Skip 'body' or 'query' prefix in location # Skip 'body' or 'query' prefix in location
field = ".".join(str(x) for x in error["loc"][1:]) field = ".".join(str(x) for x in error["loc"][1:])
errors.append(ErrorDetail( errors.append(
code=ErrorCode.VALIDATION_ERROR, ErrorDetail(
message=error["msg"], code=ErrorCode.VALIDATION_ERROR, message=error["msg"], field=field
field=field )
)) )
logger.warning( logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})")
f"Validation error: {len(errors)} errors "
f"(path: {request.url.path})"
)
error_response = ErrorResponse(errors=errors) error_response = ErrorResponse(errors=errors)
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=error_response.model_dump() content=error_response.model_dump(),
) )
@@ -226,26 +214,21 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
} }
error_code = status_code_to_error_code.get( error_code = status_code_to_error_code.get(
exc.status_code, exc.status_code, ErrorCode.INTERNAL_ERROR
ErrorCode.INTERNAL_ERROR
) )
logger.warning( logger.warning(
f"HTTP exception: {exc.status_code} - {exc.detail} " f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})"
f"(path: {request.url.path})"
) )
error_response = ErrorResponse( error_response = ErrorResponse(
errors=[ErrorDetail( errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
code=error_code,
message=str(exc.detail)
)]
) )
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content=error_response.model_dump(), content=error_response.model_dump(),
headers=exc.headers headers=exc.headers,
) )
@@ -257,26 +240,24 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
leaking sensitive information in production. leaking sensitive information in production.
""" """
logger.error( logger.error(
f"Unhandled exception: {type(exc).__name__} - {str(exc)} " f"Unhandled exception: {type(exc).__name__} - {exc!s} "
f"(path: {request.url.path})", f"(path: {request.url.path})",
exc_info=True exc_info=True,
) )
# In production, don't expose internal error details # In production, don't expose internal error details
from app.core.config import settings from app.core.config import settings
if settings.ENVIRONMENT == "production": if settings.ENVIRONMENT == "production":
message = "An internal error occurred. Please try again later." message = "An internal error occurred. Please try again later."
else: else:
message = f"{type(exc).__name__}: {str(exc)}" message = f"{type(exc).__name__}: {exc!s}"
error_response = ErrorResponse( error_response = ErrorResponse(
errors=[ErrorDetail( errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
code=ErrorCode.INTERNAL_ERROR,
message=message
)]
) )
return JSONResponse( return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=error_response.model_dump() content=error_response.model_dump(),
) )

View File

@@ -3,4 +3,4 @@ from .organization import organization
from .session import session as session_crud from .session import session as session_crud
from .user import user from .user import user
__all__ = ["user", "session_crud", "organization"] __all__ = ["organization", "session_crud", "user"]

View File

@@ -4,14 +4,16 @@ Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models. Provides reusable create, read, update, and delete operations for all models.
""" """
import logging import logging
import uuid import uuid
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple from datetime import UTC
from typing import Any, TypeVar
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, OperationalError, DataError from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load from sqlalchemy.orm import Load
@@ -24,10 +26,14 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): class CRUDBase[
ModelType: Base,
CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel,
]:
"""Async CRUD operations for a model.""" """Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]): def __init__(self, model: type[ModelType]):
""" """
CRUD object with default async methods to Create, Read, Update, Delete. CRUD object with default async methods to Create, Read, Update, Delete.
@@ -37,11 +43,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
self.model = model self.model = model
async def get( async def get(
self, self, db: AsyncSession, id: str, options: list[Load] | None = None
db: AsyncSession, ) -> ModelType | None:
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
""" """
Get a single record by ID with UUID validation and optional eager loading. Get a single record by ID with UUID validation and optional eager loading.
@@ -66,7 +69,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}") logger.warning(f"Invalid UUID format: {id} - {e!s}")
return None return None
try: try:
@@ -80,7 +83,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(query) result = await db.execute(query)
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}") logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
raise raise
async def get_multi( async def get_multi(
@@ -89,8 +92,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
*, *,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
options: Optional[List[Load]] = None options: list[Load] | None = None,
) -> List[ModelType]: ) -> list[ModelType]:
""" """
Get multiple records with pagination validation and optional eager loading. Get multiple records with pagination validation and optional eager loading.
@@ -122,10 +125,14 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(query) result = await db.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}") logger.error(
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
)
raise raise
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: # pragma: no cover async def create(
self, db: AsyncSession, *, obj_in: CreateSchemaType
) -> ModelType: # pragma: no cover
"""Create a new record with error handling. """Create a new record with error handling.
NOTE: This method is defensive code that's never called in practice. NOTE: This method is defensive code that's never called in practice.
@@ -142,19 +149,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return db_obj return db_obj
except IntegrityError as e: # pragma: no cover except IntegrityError as e: # pragma: no cover
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") logger.warning(
raise ValueError(f"A {self.model.__name__} with this data already exists") f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}") logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: # pragma: no cover except (OperationalError, DataError) as e: # pragma: no cover
await db.rollback() await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}") logger.error(f"Database error creating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {str(e)}") raise ValueError(f"Database operation failed: {e!s}")
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
await db.rollback() await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True) logger.error(
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
)
raise raise
async def update( async def update(
@@ -162,7 +175,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
db: AsyncSession, db: AsyncSession,
*, *,
db_obj: ModelType, db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]] obj_in: UpdateSchemaType | dict[str, Any],
) -> ModelType: ) -> ModelType:
"""Update a record with error handling.""" """Update a record with error handling."""
try: try:
@@ -182,22 +195,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return db_obj return db_obj
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}") logger.warning(
raise ValueError(f"A {self.model.__name__} with this data already exists") f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}") logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: except (OperationalError, DataError) as e:
await db.rollback() await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}") logger.error(f"Database error updating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {str(e)}") raise ValueError(f"Database operation failed: {e!s}")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True) logger.error(
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
)
raise raise
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""Delete a record with error handling and null check.""" """Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string # Validate UUID format and convert to UUID object if string
try: try:
@@ -206,7 +225,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}") logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
return None return None
try: try:
@@ -216,7 +235,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = result.scalar_one_or_none() obj = result.scalar_one_or_none()
if obj is None: if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion") logger.warning(
f"{self.model.__name__} with id {id} not found for deletion"
)
return None return None
await db.delete(obj) await db.delete(obj)
@@ -224,12 +245,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return obj return obj
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}") logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records") raise ValueError(
f"Cannot delete {self.model.__name__}: referenced by other records"
)
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) logger.error(
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise raise
async def get_multi_with_total( async def get_multi_with_total(
@@ -238,10 +264,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
*, *,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
sort_by: Optional[str] = None, sort_by: str | None = None,
sort_order: str = "asc", sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None filters: dict[str, Any] | None = None,
) -> Tuple[List[ModelType], int]: ) -> tuple[list[ModelType], int]:
""" """
Get multiple records with total count, filtering, and sorting. Get multiple records with total count, filtering, and sorting.
@@ -269,7 +295,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
query = select(self.model) query = select(self.model)
# Exclude soft-deleted records by default # Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'): if hasattr(self.model, "deleted_at"):
query = query.where(self.model.deleted_at.is_(None)) query = query.where(self.model.deleted_at.is_(None))
# Apply filters # Apply filters
@@ -298,7 +324,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return items, total return items, total
except Exception as e: except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}") logger.error(
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
)
raise raise
async def count(self, db: AsyncSession) -> int: async def count(self, db: AsyncSession) -> int:
@@ -307,7 +335,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(select(func.count(self.model.id))) result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one() return result.scalar_one()
except Exception as e: except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}") logger.error(f"Error counting {self.model.__name__} records: {e!s}")
raise raise
async def exists(self, db: AsyncSession, id: str) -> bool: async def exists(self, db: AsyncSession, id: str) -> bool:
@@ -315,13 +343,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = await self.get(db, id=id) obj = await self.get(db, id=id)
return obj is not None return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
""" """
Soft delete a record by setting deleted_at timestamp. Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column. Only works if the model has a 'deleted_at' column.
""" """
from datetime import datetime, timezone from datetime import datetime
# Validate UUID format and convert to UUID object if string # Validate UUID format and convert to UUID object if string
try: try:
@@ -330,7 +358,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}") logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
return None return None
try: try:
@@ -340,26 +368,33 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = result.scalar_one_or_none() obj = result.scalar_one_or_none()
if obj is None: if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion") logger.warning(
f"{self.model.__name__} with id {id} not found for soft deletion"
)
return None return None
# Check if model supports soft deletes # Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'): if not hasattr(self.model, "deleted_at"):
logger.error(f"{self.model.__name__} does not support soft deletes") logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column") raise ValueError(
f"{self.model.__name__} does not have a deleted_at column"
)
# Set deleted_at timestamp # Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc) obj.deleted_at = datetime.now(UTC)
db.add(obj) db.add(obj)
await db.commit() await db.commit()
await db.refresh(obj) await db.refresh(obj)
return obj return obj
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) logger.error(
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise raise
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]: async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
""" """
Restore a soft-deleted record by clearing the deleted_at timestamp. Restore a soft-deleted record by clearing the deleted_at timestamp.
@@ -372,25 +407,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}") logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
return None return None
try: try:
# Find the soft-deleted record # Find the soft-deleted record
if hasattr(self.model, 'deleted_at'): if hasattr(self.model, "deleted_at"):
result = await db.execute( result = await db.execute(
select(self.model).where( select(self.model).where(
self.model.id == uuid_obj, self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
self.model.deleted_at.isnot(None)
) )
) )
obj = result.scalar_one_or_none() obj = result.scalar_one_or_none()
else: else:
logger.error(f"{self.model.__name__} does not support soft deletes") logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column") raise ValueError(
f"{self.model.__name__} does not have a deleted_at column"
)
if obj is None: if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration") logger.warning(
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
)
return None return None
# Clear deleted_at timestamp # Clear deleted_at timestamp
@@ -401,5 +439,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return obj return obj
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True) logger.error(
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise raise

View File

@@ -1,17 +1,18 @@
# app/crud/organization_async.py # app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns.""" """Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
import logging import logging
from typing import Optional, List, Dict, Any from typing import Any
from uuid import UUID from uuid import UUID
from sqlalchemy import func, or_, and_, select, case from sqlalchemy import and_, case, func, or_, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase from app.crud.base import CRUDBase
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole from app.models.user_organization import OrganizationRole, UserOrganization
from app.schemas.organizations import ( from app.schemas.organizations import (
OrganizationCreate, OrganizationCreate,
OrganizationUpdate, OrganizationUpdate,
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]): class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
"""Async CRUD operations for Organization model.""" """Async CRUD operations for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]: async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
"""Get organization by slug.""" """Get organization by slug."""
try: try:
result = await db.execute( result = await db.execute(
@@ -31,10 +32,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {str(e)}") logger.error(f"Error getting organization by slug {slug}: {e!s}")
raise raise
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization: async def create(
self, db: AsyncSession, *, obj_in: OrganizationCreate
) -> Organization:
"""Create a new organization with error handling.""" """Create a new organization with error handling."""
try: try:
db_obj = Organization( db_obj = Organization(
@@ -42,7 +45,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
slug=obj_in.slug, slug=obj_in.slug,
description=obj_in.description, description=obj_in.description,
is_active=obj_in.is_active, is_active=obj_in.is_active,
settings=obj_in.settings or {} settings=obj_in.settings or {},
) )
db.add(db_obj) db.add(db_obj)
await db.commit() await db.commit()
@@ -50,15 +53,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return db_obj return db_obj
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower(): if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}") logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists") raise ValueError(
f"Organization with slug '{obj_in.slug}' already exists"
)
logger.error(f"Integrity error creating organization: {error_msg}") logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True) logger.error(
f"Unexpected error creating organization: {e!s}", exc_info=True
)
raise raise
async def get_multi_with_filters( async def get_multi_with_filters(
@@ -67,11 +74,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
*, *,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
is_active: Optional[bool] = None, is_active: bool | None = None,
search: Optional[str] = None, search: str | None = None,
sort_by: str = "created_at", sort_by: str = "created_at",
sort_order: str = "desc" sort_order: str = "desc",
) -> tuple[List[Organization], int]: ) -> tuple[list[Organization], int]:
""" """
Get multiple organizations with filtering, searching, and sorting. Get multiple organizations with filtering, searching, and sorting.
@@ -89,7 +96,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
search_filter = or_( search_filter = or_(
Organization.name.ilike(f"%{search}%"), Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"), Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%") Organization.description.ilike(f"%{search}%"),
) )
query = query.where(search_filter) query = query.where(search_filter)
@@ -112,7 +119,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return organizations, total return organizations, total
except Exception as e: except Exception as e:
logger.error(f"Error getting organizations with filters: {str(e)}") logger.error(f"Error getting organizations with filters: {e!s}")
raise raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int: async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
@@ -122,13 +129,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(func.count(UserOrganization.user_id)).where( select(func.count(UserOrganization.user_id)).where(
and_( and_(
UserOrganization.organization_id == organization_id, UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True UserOrganization.is_active,
) )
) )
) )
return result.scalar_one() or 0 return result.scalar_one() or 0
except Exception as e: except Exception as e:
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}") logger.error(
f"Error getting member count for organization {organization_id}: {e!s}"
)
raise raise
async def get_multi_with_member_counts( async def get_multi_with_member_counts(
@@ -137,9 +146,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
*, *,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
is_active: Optional[bool] = None, is_active: bool | None = None,
search: Optional[str] = None search: str | None = None,
) -> tuple[List[Dict[str, Any]], int]: ) -> tuple[list[dict[str, Any]], int]:
""" """
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY. Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem. This eliminates the N+1 query problem.
@@ -156,13 +165,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
func.count( func.count(
func.distinct( func.distinct(
case( case(
(UserOrganization.is_active == True, UserOrganization.user_id), (
else_=None UserOrganization.is_active,
UserOrganization.user_id,
),
else_=None,
) )
) )
).label('member_count') ).label("member_count"),
)
.outerjoin(
UserOrganization,
Organization.id == UserOrganization.organization_id,
) )
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.id) .group_by(Organization.id)
) )
@@ -174,7 +189,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
search_filter = or_( search_filter = or_(
Organization.name.ilike(f"%{search}%"), Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"), Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%") Organization.description.ilike(f"%{search}%"),
) )
query = query.where(search_filter) query = query.where(search_filter)
@@ -189,24 +204,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply pagination and ordering # Apply pagination and ordering
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit) query = (
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
)
result = await db.execute(query) result = await db.execute(query)
rows = result.all() rows = result.all()
# Convert to list of dicts # Convert to list of dicts
orgs_with_counts = [ orgs_with_counts = [
{ {"organization": org, "member_count": member_count}
'organization': org,
'member_count': member_count
}
for org, member_count in rows for org, member_count in rows
] ]
return orgs_with_counts, total return orgs_with_counts, total
except Exception as e: except Exception as e:
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True) logger.error(
f"Error getting organizations with member counts: {e!s}", exc_info=True
)
raise raise
async def add_user( async def add_user(
@@ -216,7 +232,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID, organization_id: UUID,
user_id: UUID, user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER, role: OrganizationRole = OrganizationRole.MEMBER,
custom_permissions: Optional[str] = None custom_permissions: str | None = None,
) -> UserOrganization: ) -> UserOrganization:
"""Add a user to an organization with a specific role.""" """Add a user to an organization with a specific role."""
try: try:
@@ -225,7 +241,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(UserOrganization).where( select(UserOrganization).where(
and_( and_(
UserOrganization.user_id == user_id, UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id UserOrganization.organization_id == organization_id,
) )
) )
) )
@@ -249,7 +265,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id=organization_id, organization_id=organization_id,
role=role, role=role,
is_active=True, is_active=True,
custom_permissions=custom_permissions custom_permissions=custom_permissions,
) )
db.add(user_org) db.add(user_org)
await db.commit() await db.commit()
@@ -257,19 +273,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org return user_org
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
logger.error(f"Integrity error adding user to organization: {str(e)}") logger.error(f"Integrity error adding user to organization: {e!s}")
raise ValueError("Failed to add user to organization") raise ValueError("Failed to add user to organization")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True) logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
raise raise
async def remove_user( async def remove_user(
self, self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID
) -> bool: ) -> bool:
"""Remove a user from an organization (soft delete).""" """Remove a user from an organization (soft delete)."""
try: try:
@@ -277,7 +289,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(UserOrganization).where( select(UserOrganization).where(
and_( and_(
UserOrganization.user_id == user_id, UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id UserOrganization.organization_id == organization_id,
) )
) )
) )
@@ -291,7 +303,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return True return True
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True) logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
raise raise
async def update_user_role( async def update_user_role(
@@ -301,15 +313,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID, organization_id: UUID,
user_id: UUID, user_id: UUID,
role: OrganizationRole, role: OrganizationRole,
custom_permissions: Optional[str] = None custom_permissions: str | None = None,
) -> Optional[UserOrganization]: ) -> UserOrganization | None:
"""Update a user's role in an organization.""" """Update a user's role in an organization."""
try: try:
result = await db.execute( result = await db.execute(
select(UserOrganization).where( select(UserOrganization).where(
and_( and_(
UserOrganization.user_id == user_id, UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id UserOrganization.organization_id == organization_id,
) )
) )
) )
@@ -326,7 +338,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org return user_org
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error updating user role: {str(e)}", exc_info=True) logger.error(f"Error updating user role: {e!s}", exc_info=True)
raise raise
async def get_organization_members( async def get_organization_members(
@@ -336,8 +348,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID, organization_id: UUID,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
is_active: bool = True is_active: bool = True,
) -> tuple[List[Dict[str, Any]], int]: ) -> tuple[list[dict[str, Any]], int]:
""" """
Get members of an organization with user details. Get members of an organization with user details.
@@ -359,46 +371,55 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
count_query = select(func.count()).select_from( count_query = select(func.count()).select_from(
select(UserOrganization) select(UserOrganization)
.where(UserOrganization.organization_id == organization_id) .where(UserOrganization.organization_id == organization_id)
.where(UserOrganization.is_active == is_active if is_active is not None else True) .where(
UserOrganization.is_active == is_active
if is_active is not None
else True
)
.alias() .alias()
) )
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply ordering and pagination # Apply ordering and pagination
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit) query = (
query.order_by(UserOrganization.created_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query) result = await db.execute(query)
results = result.all() results = result.all()
members = [] members = []
for user_org, user in results: for user_org, user in results:
members.append({ members.append(
"user_id": user.id, {
"email": user.email, "user_id": user.id,
"first_name": user.first_name, "email": user.email,
"last_name": user.last_name, "first_name": user.first_name,
"role": user_org.role, "last_name": user.last_name,
"is_active": user_org.is_active, "role": user_org.role,
"joined_at": user_org.created_at "is_active": user_org.is_active,
}) "joined_at": user_org.created_at,
}
)
return members, total return members, total
except Exception as e: except Exception as e:
logger.error(f"Error getting organization members: {str(e)}") logger.error(f"Error getting organization members: {e!s}")
raise raise
async def get_user_organizations( async def get_user_organizations(
self, self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
db: AsyncSession, ) -> list[Organization]:
*,
user_id: UUID,
is_active: bool = True
) -> List[Organization]:
"""Get all organizations a user belongs to.""" """Get all organizations a user belongs to."""
try: try:
query = ( query = (
select(Organization) select(Organization)
.join(UserOrganization, Organization.id == UserOrganization.organization_id) .join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.where(UserOrganization.user_id == user_id) .where(UserOrganization.user_id == user_id)
) )
@@ -408,16 +429,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
result = await db.execute(query) result = await db.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}") logger.error(f"Error getting user organizations: {e!s}")
raise raise
async def get_user_organizations_with_details( async def get_user_organizations_with_details(
self, self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
db: AsyncSession, ) -> list[dict[str, Any]]:
*,
user_id: UUID,
is_active: bool = True
) -> List[Dict[str, Any]]:
""" """
Get user's organizations with role and member count in SINGLE QUERY. Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts. Eliminates N+1 problem by using subquery for member counts.
@@ -430,9 +447,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
member_count_subq = ( member_count_subq = (
select( select(
UserOrganization.organization_id, UserOrganization.organization_id,
func.count(UserOrganization.user_id).label('member_count') func.count(UserOrganization.user_id).label("member_count"),
) )
.where(UserOrganization.is_active == True) .where(UserOrganization.is_active)
.group_by(UserOrganization.organization_id) .group_by(UserOrganization.organization_id)
.subquery() .subquery()
) )
@@ -442,10 +459,18 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select( select(
Organization, Organization,
UserOrganization.role, UserOrganization.role,
func.coalesce(member_count_subq.c.member_count, 0).label('member_count') func.coalesce(member_count_subq.c.member_count, 0).label(
"member_count"
),
)
.join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.outerjoin(
member_count_subq,
Organization.id == member_count_subq.c.organization_id,
) )
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
.where(UserOrganization.user_id == user_id) .where(UserOrganization.user_id == user_id)
) )
@@ -456,25 +481,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
rows = result.all() rows = result.all()
return [ return [
{ {"organization": org, "role": role, "member_count": member_count}
'organization': org,
'role': role,
'member_count': member_count
}
for org, role, member_count in rows for org, role, member_count in rows
] ]
except Exception as e: except Exception as e:
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True) logger.error(
f"Error getting user organizations with details: {e!s}", exc_info=True
)
raise raise
async def get_user_role_in_org( async def get_user_role_in_org(
self, self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
db: AsyncSession, ) -> OrganizationRole | None:
*,
user_id: UUID,
organization_id: UUID
) -> Optional[OrganizationRole]:
"""Get a user's role in a specific organization.""" """Get a user's role in a specific organization."""
try: try:
result = await db.execute( result = await db.execute(
@@ -482,7 +501,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
and_( and_(
UserOrganization.user_id == user_id, UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id, UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True UserOrganization.is_active,
) )
) )
) )
@@ -490,29 +509,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org.role if user_org else None return user_org.role if user_org else None
except Exception as e: except Exception as e:
logger.error(f"Error getting user role in org: {str(e)}") logger.error(f"Error getting user role in org: {e!s}")
raise raise
async def is_user_org_owner( async def is_user_org_owner(
self, self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool: ) -> bool:
"""Check if a user is an owner of an organization.""" """Check if a user is an owner of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role == OrganizationRole.OWNER return role == OrganizationRole.OWNER
async def is_user_org_admin( async def is_user_org_admin(
self, self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> bool: ) -> bool:
"""Check if a user is an owner or admin of an organization.""" """Check if a user is an owner or admin of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id) role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]

View File

@@ -1,13 +1,13 @@
""" """
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
""" """
import logging import logging
import uuid import uuid
from datetime import datetime, timezone, timedelta from datetime import UTC, datetime, timedelta
from typing import List, Optional
from uuid import UUID from uuid import UUID
from sqlalchemy import and_, select, update, delete, func from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions.""" """Async CRUD operations for user sessions."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
""" """
Get session by refresh token JTI. Get session by refresh token JTI.
@@ -38,10 +38,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}") logger.error(f"Error getting session by JTI {jti}: {e!s}")
raise raise
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]: async def get_active_by_jti(
self, db: AsyncSession, *, jti: str
) -> UserSession | None:
""" """
Get active session by refresh token JTI. Get active session by refresh token JTI.
@@ -57,13 +59,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
select(UserSession).where( select(UserSession).where(
and_( and_(
UserSession.refresh_token_jti == jti, UserSession.refresh_token_jti == jti,
UserSession.is_active == True UserSession.is_active,
) )
) )
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}") logger.error(f"Error getting active session by JTI {jti}: {e!s}")
raise raise
async def get_user_sessions( async def get_user_sessions(
@@ -72,8 +74,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
*, *,
user_id: str, user_id: str,
active_only: bool = True, active_only: bool = True,
with_user: bool = False with_user: bool = False,
) -> List[UserSession]: ) -> list[UserSession]:
""" """
Get all sessions for a user with optional eager loading. Get all sessions for a user with optional eager loading.
@@ -97,20 +99,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
query = query.options(joinedload(UserSession.user)) query = query.options(joinedload(UserSession.user))
if active_only: if active_only:
query = query.where(UserSession.is_active == True) query = query.where(UserSession.is_active)
query = query.order_by(UserSession.last_used_at.desc()) query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query) result = await db.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}") logger.error(f"Error getting sessions for user {user_id}: {e!s}")
raise raise
async def create_session( async def create_session(
self, self, db: AsyncSession, *, obj_in: SessionCreate
db: AsyncSession,
*,
obj_in: SessionCreate
) -> UserSession: ) -> UserSession:
""" """
Create a new user session. Create a new user session.
@@ -151,10 +150,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return db_obj return db_obj
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True) logger.error(f"Error creating session: {e!s}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}") raise ValueError(f"Failed to create session: {e!s}")
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]: async def deactivate(
self, db: AsyncSession, *, session_id: str
) -> UserSession | None:
""" """
Deactivate a session (logout from device). Deactivate a session (logout from device).
@@ -184,14 +185,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return session return session
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}") logger.error(f"Error deactivating session {session_id}: {e!s}")
raise raise
async def deactivate_all_user_sessions( async def deactivate_all_user_sessions(
self, self, db: AsyncSession, *, user_id: str
db: AsyncSession,
*,
user_id: str
) -> int: ) -> int:
""" """
Deactivate all active sessions for a user (logout from all devices). Deactivate all active sessions for a user (logout from all devices).
@@ -209,12 +207,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
stmt = ( stmt = (
update(UserSession) update(UserSession)
.where( .where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
.values(is_active=False) .values(is_active=False)
) )
@@ -228,14 +221,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return count return count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}") logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
raise raise
async def update_last_used( async def update_last_used(
self, self, db: AsyncSession, *, session: UserSession
db: AsyncSession,
*,
session: UserSession
) -> UserSession: ) -> UserSession:
""" """
Update the last_used_at timestamp for a session. Update the last_used_at timestamp for a session.
@@ -248,14 +238,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Updated UserSession Updated UserSession
""" """
try: try:
session.last_used_at = datetime.now(timezone.utc) session.last_used_at = datetime.now(UTC)
db.add(session) db.add(session)
await db.commit() await db.commit()
await db.refresh(session) await db.refresh(session)
return session return session
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}") logger.error(f"Error updating last_used for session {session.id}: {e!s}")
raise raise
async def update_refresh_token( async def update_refresh_token(
@@ -264,7 +254,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
*, *,
session: UserSession, session: UserSession,
new_jti: str, new_jti: str,
new_expires_at: datetime new_expires_at: datetime,
) -> UserSession: ) -> UserSession:
""" """
Update session with new refresh token JTI and expiration. Update session with new refresh token JTI and expiration.
@@ -283,14 +273,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
try: try:
session.refresh_token_jti = new_jti session.refresh_token_jti = new_jti
session.expires_at = new_expires_at session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc) session.last_used_at = datetime.now(UTC)
db.add(session) db.add(session)
await db.commit() await db.commit()
await db.refresh(session) await db.refresh(session)
return session return session
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}") logger.error(
f"Error updating refresh token for session {session.id}: {e!s}"
)
raise raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int: async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
@@ -311,15 +303,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Number of sessions deleted Number of sessions deleted
""" """
try: try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days) cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
now = datetime.now(timezone.utc) now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query # Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where( stmt = delete(UserSession).where(
and_( and_(
UserSession.is_active == False, not UserSession.is_active,
UserSession.expires_at < now, UserSession.expires_at < now,
UserSession.created_at < cutoff_date UserSession.created_at < cutoff_date,
) )
) )
@@ -334,15 +326,10 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return count return count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}") logger.error(f"Error cleaning up expired sessions: {e!s}")
raise raise
async def cleanup_expired_for_user( async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
self,
db: AsyncSession,
*,
user_id: str
) -> int:
""" """
Clean up expired and inactive sessions for a specific user. Clean up expired and inactive sessions for a specific user.
@@ -363,14 +350,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
logger.error(f"Invalid UUID format: {user_id}") logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}") raise ValueError(f"Invalid user ID format: {user_id}")
now = datetime.now(timezone.utc) now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query # Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where( stmt = delete(UserSession).where(
and_( and_(
UserSession.user_id == uuid_obj, UserSession.user_id == uuid_obj,
UserSession.is_active == False, not UserSession.is_active,
UserSession.expires_at < now UserSession.expires_at < now,
) )
) )
@@ -388,7 +375,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.error(
f"Error cleaning up expired sessions for user {user_id}: {str(e)}" f"Error cleaning up expired sessions for user {user_id}: {e!s}"
) )
raise raise
@@ -409,15 +396,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
result = await db.execute( result = await db.execute(
select(func.count(UserSession.id)).where( select(func.count(UserSession.id)).where(
and_( and_(UserSession.user_id == user_uuid, UserSession.is_active)
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
) )
) )
return result.scalar_one() return result.scalar_one()
except Exception as e: except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}") logger.error(f"Error counting sessions for user {user_id}: {e!s}")
raise raise
async def get_all_sessions( async def get_all_sessions(
@@ -427,8 +411,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
active_only: bool = True, active_only: bool = True,
with_user: bool = True with_user: bool = True,
) -> tuple[List[UserSession], int]: ) -> tuple[list[UserSession], int]:
""" """
Get all sessions across all users with pagination (admin only). Get all sessions across all users with pagination (admin only).
@@ -451,18 +435,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
query = query.options(joinedload(UserSession.user)) query = query.options(joinedload(UserSession.user))
if active_only: if active_only:
query = query.where(UserSession.is_active == True) query = query.where(UserSession.is_active)
# Get total count # Get total count
count_query = select(func.count(UserSession.id)) count_query = select(func.count(UserSession.id))
if active_only: if active_only:
count_query = count_query.where(UserSession.is_active == True) count_query = count_query.where(UserSession.is_active)
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply pagination and ordering # Apply pagination and ordering
query = query.order_by(UserSession.last_used_at.desc()).offset(skip).limit(limit) query = (
query.order_by(UserSession.last_used_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query) result = await db.execute(query)
sessions = list(result.scalars().all()) sessions = list(result.scalars().all())
@@ -470,7 +458,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return sessions, total return sessions, total
except Exception as e: except Exception as e:
logger.error(f"Error getting all sessions: {str(e)}", exc_info=True) logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
raise raise

View File

@@ -1,8 +1,9 @@
# app/crud/user_async.py # app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns.""" """Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
import logging import logging
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Optional, Union, Dict, Any, List, Tuple from typing import Any
from uuid import UUID from uuid import UUID
from sqlalchemy import or_, select, update from sqlalchemy import or_, select, update
@@ -20,15 +21,13 @@ logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model.""" """Async CRUD operations for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]: async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
"""Get user by email address.""" """Get user by email address."""
try: try:
result = await db.execute( result = await db.execute(select(User).where(User.email == email))
select(User).where(User.email == email)
)
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting user by email {email}: {str(e)}") logger.error(f"Error getting user by email {email}: {e!s}")
raise raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User: async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
@@ -42,9 +41,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
password_hash=password_hash, password_hash=password_hash,
first_name=obj_in.first_name, first_name=obj_in.first_name,
last_name=obj_in.last_name, last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None, phone_number=obj_in.phone_number
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False, if hasattr(obj_in, "phone_number")
preferences={} else None,
is_superuser=obj_in.is_superuser
if hasattr(obj_in, "is_superuser")
else False,
preferences={},
) )
db.add(db_obj) db.add(db_obj)
await db.commit() await db.commit()
@@ -52,7 +55,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return db_obj return db_obj
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower(): if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}") logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise ValueError(f"User with email {obj_in.email} already exists") raise ValueError(f"User with email {obj_in.email} already exists")
@@ -60,15 +63,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
raise ValueError(f"Database integrity error: {error_msg}") raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True) logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
raise raise
async def update( async def update(
self, self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
db: AsyncSession,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
) -> User: ) -> User:
"""Update user with async password hashing if password is updated.""" """Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict): if isinstance(obj_in, dict):
@@ -79,7 +78,9 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
# Handle password separately if it exists in update data # Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop # Hash password asynchronously to avoid blocking event loop
if "password" in update_data: if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async(update_data["password"]) update_data["password_hash"] = await get_password_hash_async(
update_data["password"]
)
del update_data["password"] del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data) return await super().update(db, db_obj=db_obj, obj_in=update_data)
@@ -90,11 +91,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
*, *,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
sort_by: Optional[str] = None, sort_by: str | None = None,
sort_order: str = "asc", sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None, filters: dict[str, Any] | None = None,
search: Optional[str] = None search: str | None = None,
) -> Tuple[List[User], int]: ) -> tuple[list[User], int]:
""" """
Get multiple users with total count, filtering, sorting, and search. Get multiple users with total count, filtering, sorting, and search.
@@ -136,12 +137,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
search_filter = or_( search_filter = or_(
User.email.ilike(f"%{search}%"), User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"), User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%") User.last_name.ilike(f"%{search}%"),
) )
query = query.where(search_filter) query = query.where(search_filter)
# Get total count # Get total count
from sqlalchemy import func from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias()) count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
@@ -162,15 +164,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return users, total return users, total
except Exception as e: except Exception as e:
logger.error(f"Error retrieving paginated users: {str(e)}") logger.error(f"Error retrieving paginated users: {e!s}")
raise raise
async def bulk_update_status( async def bulk_update_status(
self, self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
) -> int: ) -> int:
""" """
Bulk update is_active status for multiple users. Bulk update is_active status for multiple users.
@@ -192,7 +190,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
update(User) update(User)
.where(User.id.in_(user_ids)) .where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users .where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc)) .values(is_active=is_active, updated_at=datetime.now(UTC))
) )
result = await db.execute(stmt) result = await db.execute(stmt)
@@ -204,15 +202,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True) logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
raise raise
async def bulk_soft_delete( async def bulk_soft_delete(
self, self,
db: AsyncSession, db: AsyncSession,
*, *,
user_ids: List[UUID], user_ids: list[UUID],
exclude_user_id: Optional[UUID] = None exclude_user_id: UUID | None = None,
) -> int: ) -> int:
""" """
Bulk soft delete multiple users. Bulk soft delete multiple users.
@@ -239,11 +237,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
stmt = ( stmt = (
update(User) update(User)
.where(User.id.in_(filtered_ids)) .where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users .where(
User.deleted_at.is_(None)
) # Don't re-delete already deleted users
.values( .values(
deleted_at=datetime.now(timezone.utc), deleted_at=datetime.now(UTC),
is_active=False, is_active=False,
updated_at=datetime.now(timezone.utc) updated_at=datetime.now(UTC),
) )
) )
@@ -256,7 +256,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True) logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
raise raise
def is_active(self, user: User) -> bool: def is_active(self, user: User) -> bool:

View File

@@ -4,9 +4,9 @@ Async database initialization script.
Creates the first superuser if configured and doesn't already exist. Creates the first superuser if configured and doesn't already exist.
""" """
import asyncio import asyncio
import logging import logging
from typing import Optional
from app.core.config import settings from app.core.config import settings
from app.core.database import SessionLocal, engine from app.core.database import SessionLocal, engine
@@ -17,7 +17,7 @@ from app.schemas.users import UserCreate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def init_db() -> Optional[User]: async def init_db() -> User | None:
""" """
Initialize database with first superuser if settings are configured and user doesn't exist. Initialize database with first superuser if settings are configured and user doesn't exist.
@@ -49,7 +49,7 @@ async def init_db() -> Optional[User]:
password=superuser_password, password=superuser_password,
first_name="Admin", first_name="Admin",
last_name="User", last_name="User",
is_superuser=True is_superuser=True,
) )
user = await user_crud.create(session, obj_in=user_in) user = await user_crud.create(session, obj_in=user_in)
@@ -70,13 +70,13 @@ async def main():
# Configure logging to show info logs # Configure logging to show info logs
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
) )
try: try:
user = await init_db() user = await init_db()
if user: if user:
print(f"✓ Database initialized successfully") print("✓ Database initialized successfully")
print(f"✓ Superuser: {user.email}") print(f"✓ Superuser: {user.email}")
else: else:
print("✗ Failed to initialize database") print("✗ Failed to initialize database")

View File

@@ -2,10 +2,10 @@ import logging
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from typing import Dict, Any from typing import Any
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import FastAPI, status, Request, HTTPException from fastapi import FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse
@@ -19,9 +19,9 @@ from app.core.database import check_database_health
from app.core.exceptions import ( from app.core.exceptions import (
APIException, APIException,
api_exception_handler, api_exception_handler,
validation_exception_handler,
http_exception_handler, http_exception_handler,
unhandled_exception_handler unhandled_exception_handler,
validation_exception_handler,
) )
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
@@ -52,11 +52,11 @@ async def lifespan(app: FastAPI):
# Runs daily at 2:00 AM server time # Runs daily at 2:00 AM server time
scheduler.add_job( scheduler.add_job(
cleanup_expired_sessions, cleanup_expired_sessions,
'cron', "cron",
hour=2, hour=2,
minute=0, minute=0,
id='cleanup_expired_sessions', id="cleanup_expired_sessions",
replace_existing=True replace_existing=True,
) )
scheduler.start() scheduler.start()
@@ -73,12 +73,12 @@ async def lifespan(app: FastAPI):
logger.info("Scheduled jobs stopped") logger.info("Scheduled jobs stopped")
logger.info(f"Starting app!!!") logger.info("Starting app!!!")
app = FastAPI( app = FastAPI(
title=settings.PROJECT_NAME, title=settings.PROJECT_NAME,
version=settings.VERSION, version=settings.VERSION,
openapi_url=f"{settings.API_V1_STR}/openapi.json", openapi_url=f"{settings.API_V1_STR}/openapi.json",
lifespan=lifespan lifespan=lifespan,
) )
# Add rate limiter state to app # Add rate limiter state to app
@@ -96,7 +96,14 @@ app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.BACKEND_CORS_ORIGINS, allow_origins=settings.BACKEND_CORS_ORIGINS,
allow_credentials=True, allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only allow_methods=[
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"OPTIONS",
], # Explicit methods only
allow_headers=[ allow_headers=[
"Content-Type", "Content-Type",
"Authorization", "Authorization",
@@ -129,12 +136,14 @@ async def limit_request_size(request: Request, call_next):
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content={ content={
"success": False, "success": False,
"errors": [{ "errors": [
"code": "REQUEST_TOO_LARGE", {
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB", "code": "REQUEST_TOO_LARGE",
"field": None "message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
}] "field": None,
} }
],
},
) )
response = await call_next(request) response = await call_next(request)
@@ -165,15 +174,19 @@ async def add_security_headers(request: Request, call_next):
# Enforce HTTPS in production # Enforce HTTPS in production
if settings.ENVIRONMENT == "production": if settings.ENVIRONMENT == "production":
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains"
)
# Content Security Policy # Content Security Policy
csp_mode = settings.CSP_MODE.lower() csp_mode = settings.CSP_MODE.lower()
# Special handling for API docs # Special handling for API docs
is_docs = request.url.path in ["/docs", "/redoc"] or \ is_docs = (
request.url.path.startswith("/docs/") or \ request.url.path in ["/docs", "/redoc"]
request.url.path.startswith("/redoc/") or request.url.path.startswith("/docs/")
or request.url.path.startswith("/redoc/")
)
if csp_mode == "disabled": if csp_mode == "disabled":
# No CSP (only for local development/debugging) # No CSP (only for local development/debugging)
@@ -264,7 +277,7 @@ async def root():
description="Check the health status of the API and its dependencies", description="Check the health status of the API and its dependencies",
response_description="Health status information", response_description="Health status information",
tags=["Health"], tags=["Health"],
operation_id="health_check" operation_id="health_check",
) )
async def health_check() -> JSONResponse: async def health_check() -> JSONResponse:
""" """
@@ -278,12 +291,12 @@ async def health_check() -> JSONResponse:
- environment: Current environment (development, staging, production) - environment: Current environment (development, staging, production)
- database: Database connectivity status - database: Database connectivity status
""" """
health_status: Dict[str, Any] = { health_status: dict[str, Any] = {
"status": "healthy", "status": "healthy",
"timestamp": datetime.utcnow().isoformat() + "Z", "timestamp": datetime.utcnow().isoformat() + "Z",
"version": settings.VERSION, "version": settings.VERSION,
"environment": settings.ENVIRONMENT, "environment": settings.ENVIRONMENT,
"checks": {} "checks": {},
} }
response_status = status.HTTP_200_OK response_status = status.HTTP_200_OK
@@ -294,7 +307,7 @@ async def health_check() -> JSONResponse:
if db_healthy: if db_healthy:
health_status["checks"]["database"] = { health_status["checks"]["database"] = {
"status": "healthy", "status": "healthy",
"message": "Database connection successful" "message": "Database connection successful",
} }
else: else:
raise Exception("Database health check returned unhealthy status") raise Exception("Database health check returned unhealthy status")
@@ -302,15 +315,12 @@ async def health_check() -> JSONResponse:
health_status["status"] = "unhealthy" health_status["status"] = "unhealthy"
health_status["checks"]["database"] = { health_status["checks"]["database"] = {
"status": "unhealthy", "status": "unhealthy",
"message": f"Database connection failed: {str(e)}" "message": f"Database connection failed: {e!s}",
} }
response_status = status.HTTP_503_SERVICE_UNAVAILABLE response_status = status.HTTP_503_SERVICE_UNAVAILABLE
logger.error(f"Health check failed - database error: {e}") logger.error(f"Health check failed - database error: {e}")
return JSONResponse( return JSONResponse(status_code=response_status, content=health_status)
status_code=response_status,
content=health_status
)
app.include_router(api_router, prefix=settings.API_V1_STR) app.include_router(api_router, prefix=settings.API_V1_STR)

View File

@@ -2,17 +2,25 @@
Models package initialization. Models package initialization.
Imports all models to ensure they're registered with SQLAlchemy. Imports all models to ensure they're registered with SQLAlchemy.
""" """
# First import Base to avoid circular imports # First import Base to avoid circular imports
from app.core.database import Base from app.core.database import Base
from .base import TimestampMixin, UUIDMixin from .base import TimestampMixin, UUIDMixin
from .organization import Organization from .organization import Organization
# Import models # Import models
from .user import User from .user import User
from .user_organization import UserOrganization, OrganizationRole from .user_organization import OrganizationRole, UserOrganization
from .user_session import UserSession from .user_session import UserSession
__all__ = [ __all__ = [
'Base', 'TimestampMixin', 'UUIDMixin', "Base",
'User', 'UserSession', "Organization",
'Organization', 'UserOrganization', 'OrganizationRole', "OrganizationRole",
"TimestampMixin",
"UUIDMixin",
"User",
"UserOrganization",
"UserSession",
] ]

View File

@@ -1,20 +1,27 @@
import uuid import uuid
from datetime import datetime, timezone from datetime import UTC, datetime
from sqlalchemy import Column, DateTime from sqlalchemy import Column, DateTime
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from app.core.database import Base
class TimestampMixin: class TimestampMixin:
"""Mixin to add created_at and updated_at timestamps to models""" """Mixin to add created_at and updated_at timestamps to models"""
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), created_at = Column(
onupdate=lambda: datetime.now(timezone.utc), nullable=False) DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
nullable=False,
)
class UUIDMixin: class UUIDMixin:
"""Mixin to add UUID primary keys to models""" """Mixin to add UUID primary keys to models"""
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

View File

@@ -1,5 +1,5 @@
# app/models/organization.py # app/models/organization.py
from sqlalchemy import Column, String, Boolean, Text, Index from sqlalchemy import Boolean, Column, Index, String, Text
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@@ -11,7 +11,8 @@ class Organization(Base, UUIDMixin, TimestampMixin):
Organization model for multi-tenant support. Organization model for multi-tenant support.
Users can belong to multiple organizations with different roles. Users can belong to multiple organizations with different roles.
""" """
__tablename__ = 'organizations'
__tablename__ = "organizations"
name = Column(String(255), nullable=False, index=True) name = Column(String(255), nullable=False, index=True)
slug = Column(String(255), unique=True, nullable=False, index=True) slug = Column(String(255), unique=True, nullable=False, index=True)
@@ -20,11 +21,13 @@ class Organization(Base, UUIDMixin, TimestampMixin):
settings = Column(JSONB, default={}) settings = Column(JSONB, default={})
# Relationships # Relationships
user_organizations = relationship("UserOrganization", back_populates="organization", cascade="all, delete-orphan") user_organizations = relationship(
"UserOrganization", back_populates="organization", cascade="all, delete-orphan"
)
__table_args__ = ( __table_args__ = (
Index('ix_organizations_name_active', 'name', 'is_active'), Index("ix_organizations_name_active", "name", "is_active"),
Index('ix_organizations_slug_active', 'slug', 'is_active'), Index("ix_organizations_slug_active", "slug", "is_active"),
) )
def __repr__(self): def __repr__(self):

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, String, Boolean, DateTime from sqlalchemy import Boolean, Column, DateTime, String
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@@ -6,7 +6,7 @@ from .base import Base, TimestampMixin, UUIDMixin
class User(Base, UUIDMixin, TimestampMixin): class User(Base, UUIDMixin, TimestampMixin):
__tablename__ = 'users' __tablename__ = "users"
email = Column(String(255), unique=True, nullable=False, index=True) email = Column(String(255), unique=True, nullable=False, index=True)
password_hash = Column(String(255), nullable=False) password_hash = Column(String(255), nullable=False)
@@ -19,7 +19,9 @@ class User(Base, UUIDMixin, TimestampMixin):
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
# Relationships # Relationships
user_organizations = relationship("UserOrganization", back_populates="user", cascade="all, delete-orphan") user_organizations = relationship(
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self): def __repr__(self):
return f"<User {self.email}>" return f"<User {self.email}>"

View File

@@ -1,7 +1,7 @@
# app/models/user_organization.py # app/models/user_organization.py
from enum import Enum as PyEnum from enum import Enum as PyEnum
from sqlalchemy import Column, ForeignKey, Boolean, String, Index, Enum from sqlalchemy import Boolean, Column, Enum, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID as PGUUID from sqlalchemy.dialects.postgresql import UUID as PGUUID
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@@ -14,6 +14,7 @@ class OrganizationRole(str, PyEnum):
These provide a baseline role system that can be optionally used. These provide a baseline role system that can be optionally used.
Projects can extend this or implement their own permission system. Projects can extend this or implement their own permission system.
""" """
OWNER = "owner" # Full control over organization OWNER = "owner" # Full control over organization
ADMIN = "admin" # Can manage users and settings ADMIN = "admin" # Can manage users and settings
MEMBER = "member" # Regular member with standard access MEMBER = "member" # Regular member with standard access
@@ -25,25 +26,41 @@ class UserOrganization(Base, TimestampMixin):
Junction table for many-to-many relationship between Users and Organizations. Junction table for many-to-many relationship between Users and Organizations.
Includes role information for flexible RBAC. Includes role information for flexible RBAC.
""" """
__tablename__ = 'user_organizations'
user_id = Column(PGUUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), primary_key=True) __tablename__ = "user_organizations"
organization_id = Column(PGUUID(as_uuid=True), ForeignKey('organizations.id', ondelete='CASCADE'), primary_key=True)
role = Column(Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False, index=True) user_id = Column(
PGUUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
primary_key=True,
)
organization_id = Column(
PGUUID(as_uuid=True),
ForeignKey("organizations.id", ondelete="CASCADE"),
primary_key=True,
)
role = Column(
Enum(OrganizationRole),
default=OrganizationRole.MEMBER,
nullable=False,
index=True,
)
is_active = Column(Boolean, default=True, nullable=False, index=True) is_active = Column(Boolean, default=True, nullable=False, index=True)
# Optional: Custom permissions override for specific users # Optional: Custom permissions override for specific users
custom_permissions = Column(String(500), nullable=True) # JSON array of permission strings custom_permissions = Column(
String(500), nullable=True
) # JSON array of permission strings
# Relationships # Relationships
user = relationship("User", back_populates="user_organizations") user = relationship("User", back_populates="user_organizations")
organization = relationship("Organization", back_populates="user_organizations") organization = relationship("Organization", back_populates="user_organizations")
__table_args__ = ( __table_args__ = (
Index('ix_user_org_user_active', 'user_id', 'is_active'), Index("ix_user_org_user_active", "user_id", "is_active"),
Index('ix_user_org_org_active', 'organization_id', 'is_active'), Index("ix_user_org_org_active", "organization_id", "is_active"),
Index('ix_user_org_role', 'role'), Index("ix_user_org_role", "role"),
) )
def __repr__(self): def __repr__(self):

View File

@@ -6,7 +6,10 @@ This allows users to:
- Logout from specific devices - Logout from specific devices
- Manage their active sessions - Manage their active sessions
""" """
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Index
from datetime import UTC
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@@ -20,19 +23,27 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
Each time a user logs in from a device, a new session is created. Each time a user logs in from a device, a new session is created.
Sessions are identified by the refresh token JTI (JWT ID). Sessions are identified by the refresh token JTI (JWT ID).
""" """
__tablename__ = 'user_sessions'
__tablename__ = "user_sessions"
# Foreign key to user # Foreign key to user
user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False, index=True) user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Refresh token identifier (JWT ID from the refresh token) # Refresh token identifier (JWT ID from the refresh token)
refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True) refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True)
# Device information # Device information
device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook" device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook"
device_id = Column(String(255), nullable=True) # Persistent device identifier (from client) device_id = Column(
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars) String(255), nullable=True
user_agent = Column(String(500), nullable=True) # Browser/app user agent ) # Persistent device identifier (from client)
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
user_agent = Column(String(500), nullable=True) # Browser/app user agent
# Session timing # Session timing
last_used_at = Column(DateTime(timezone=True), nullable=False) last_used_at = Column(DateTime(timezone=True), nullable=False)
@@ -50,8 +61,8 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
# Composite indexes for performance (defined in migration) # Composite indexes for performance (defined in migration)
__table_args__ = ( __table_args__ = (
Index('ix_user_sessions_user_active', 'user_id', 'is_active'), Index("ix_user_sessions_user_active", "user_id", "is_active"),
Index('ix_user_sessions_jti_active', 'refresh_token_jti', 'is_active'), Index("ix_user_sessions_jti_active", "refresh_token_jti", "is_active"),
) )
def __repr__(self): def __repr__(self):
@@ -60,21 +71,24 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
@property @property
def is_expired(self) -> bool: def is_expired(self) -> bool:
"""Check if session has expired.""" """Check if session has expired."""
from datetime import datetime, timezone from datetime import datetime
return self.expires_at < datetime.now(timezone.utc)
return self.expires_at < datetime.now(UTC)
def to_dict(self): def to_dict(self):
"""Convert session to dictionary for serialization.""" """Convert session to dictionary for serialization."""
return { return {
'id': str(self.id), "id": str(self.id),
'user_id': str(self.user_id), "user_id": str(self.user_id),
'device_name': self.device_name, "device_name": self.device_name,
'device_id': self.device_id, "device_id": self.device_id,
'ip_address': self.ip_address, "ip_address": self.ip_address,
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None, "last_used_at": self.last_used_at.isoformat()
'expires_at': self.expires_at.isoformat() if self.expires_at else None, if self.last_used_at
'is_active': self.is_active, else None,
'location_city': self.location_city, "expires_at": self.expires_at.isoformat() if self.expires_at else None,
'location_country': self.location_country, "is_active": self.is_active,
'created_at': self.created_at.isoformat() if self.created_at else None, "location_city": self.location_city,
"location_country": self.location_country,
"created_at": self.created_at.isoformat() if self.created_at else None,
} }

View File

@@ -1,18 +1,20 @@
""" """
Common schemas used across the API for pagination, responses, filtering, and sorting. Common schemas used across the API for pagination, responses, filtering, and sorting.
""" """
from enum import Enum from enum import Enum
from math import ceil from math import ceil
from typing import Generic, TypeVar, List, Optional from typing import TypeVar
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
T = TypeVar('T') T = TypeVar("T")
class SortOrder(str, Enum): class SortOrder(str, Enum):
"""Sort order options.""" """Sort order options."""
ASC = "asc" ASC = "asc"
DESC = "desc" DESC = "desc"
@@ -20,16 +22,9 @@ class SortOrder(str, Enum):
class PaginationParams(BaseModel): class PaginationParams(BaseModel):
"""Parameters for pagination.""" """Parameters for pagination."""
page: int = Field( page: int = Field(default=1, ge=1, description="Page number (1-indexed)")
default=1,
ge=1,
description="Page number (1-indexed)"
)
limit: int = Field( limit: int = Field(
default=20, default=20, ge=1, le=100, description="Number of items per page (max 100)"
ge=1,
le=100,
description="Number of items per page (max 100)"
) )
@property @property
@@ -42,34 +37,20 @@ class PaginationParams(BaseModel):
"""Alias for offset (compatibility with existing code).""" """Alias for offset (compatibility with existing code)."""
return self.offset return self.offset
model_config = { model_config = {"json_schema_extra": {"example": {"page": 1, "limit": 20}}}
"json_schema_extra": {
"example": {
"page": 1,
"limit": 20
}
}
}
class SortParams(BaseModel): class SortParams(BaseModel):
"""Parameters for sorting.""" """Parameters for sorting."""
sort_by: Optional[str] = Field( sort_by: str | None = Field(default=None, description="Field name to sort by")
default=None,
description="Field name to sort by"
)
sort_order: SortOrder = Field( sort_order: SortOrder = Field(
default=SortOrder.ASC, default=SortOrder.ASC, description="Sort order (asc or desc)"
description="Sort order (asc or desc)"
) )
model_config = { model_config = {
"json_schema_extra": { "json_schema_extra": {
"example": { "example": {"sort_by": "created_at", "sort_order": "desc"}
"sort_by": "created_at",
"sort_order": "desc"
}
} }
} }
@@ -92,32 +73,30 @@ class PaginationMeta(BaseModel):
"page_size": 20, "page_size": 20,
"total_pages": 8, "total_pages": 8,
"has_next": True, "has_next": True,
"has_prev": False "has_prev": False,
} }
} }
} }
class PaginatedResponse(BaseModel, Generic[T]): class PaginatedResponse[T](BaseModel):
"""Generic paginated response wrapper.""" """Generic paginated response wrapper."""
data: List[T] = Field(..., description="List of items") data: list[T] = Field(..., description="List of items")
pagination: PaginationMeta = Field(..., description="Pagination metadata") pagination: PaginationMeta = Field(..., description="Pagination metadata")
model_config = { model_config = {
"json_schema_extra": { "json_schema_extra": {
"example": { "example": {
"data": [ "data": [{"id": "123", "name": "Example Item"}],
{"id": "123", "name": "Example Item"}
],
"pagination": { "pagination": {
"total": 150, "total": 150,
"page": 1, "page": 1,
"page_size": 20, "page_size": 20,
"total_pages": 8, "total_pages": 8,
"has_next": True, "has_next": True,
"has_prev": False "has_prev": False,
} },
} }
} }
} }
@@ -131,10 +110,7 @@ class MessageResponse(BaseModel):
model_config = { model_config = {
"json_schema_extra": { "json_schema_extra": {
"example": { "example": {"success": True, "message": "Operation completed successfully"}
"success": True,
"message": "Operation completed successfully"
}
} }
} }
@@ -142,11 +118,11 @@ class MessageResponse(BaseModel):
class BulkActionRequest(BaseModel): class BulkActionRequest(BaseModel):
"""Request schema for bulk operations on multiple items.""" """Request schema for bulk operations on multiple items."""
ids: List[UUID] = Field( ids: list[UUID] = Field(
..., ...,
min_length=1, min_length=1,
max_length=100, max_length=100,
description="List of item IDs to perform action on (max 100)" description="List of item IDs to perform action on (max 100)",
) )
model_config = { model_config = {
@@ -154,7 +130,7 @@ class BulkActionRequest(BaseModel):
"example": { "example": {
"ids": [ "ids": [
"550e8400-e29b-41d4-a716-446655440000", "550e8400-e29b-41d4-a716-446655440000",
"6ba7b810-9dad-11d1-80b4-00c04fd430c8" "6ba7b810-9dad-11d1-80b4-00c04fd430c8",
] ]
} }
} }
@@ -166,24 +142,23 @@ class BulkActionResponse(BaseModel):
success: bool = Field(default=True, description="Operation success status") success: bool = Field(default=True, description="Operation success status")
message: str = Field(..., description="Human-readable message") message: str = Field(..., description="Human-readable message")
affected_count: int = Field(..., description="Number of items affected by the operation") affected_count: int = Field(
..., description="Number of items affected by the operation"
)
model_config = { model_config = {
"json_schema_extra": { "json_schema_extra": {
"example": { "example": {
"success": True, "success": True,
"message": "Successfully deactivated 5 users", "message": "Successfully deactivated 5 users",
"affected_count": 5 "affected_count": 5,
} }
} }
} }
def create_pagination_meta( def create_pagination_meta(
total: int, total: int, page: int, limit: int, items_count: int
page: int,
limit: int,
items_count: int
) -> PaginationMeta: ) -> PaginationMeta:
""" """
Helper function to create pagination metadata. Helper function to create pagination metadata.
@@ -205,5 +180,5 @@ def create_pagination_meta(
page_size=items_count, page_size=items_count,
total_pages=total_pages, total_pages=total_pages,
has_next=page < total_pages, has_next=page < total_pages,
has_prev=page > 1 has_prev=page > 1,
) )

View File

@@ -1,8 +1,8 @@
""" """
Error schemas for standardized API error responses. Error schemas for standardized API error responses.
""" """
from enum import Enum from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -53,14 +53,14 @@ class ErrorDetail(BaseModel):
code: ErrorCode = Field(..., description="Machine-readable error code") code: ErrorCode = Field(..., description="Machine-readable error code")
message: str = Field(..., description="Human-readable error message") message: str = Field(..., description="Human-readable error message")
field: Optional[str] = Field(None, description="Field name if error is field-specific") field: str | None = Field(None, description="Field name if error is field-specific")
model_config = { model_config = {
"json_schema_extra": { "json_schema_extra": {
"example": { "example": {
"code": "VAL_002", "code": "VAL_002",
"message": "Password must be at least 8 characters long", "message": "Password must be at least 8 characters long",
"field": "password" "field": "password",
} }
} }
} }
@@ -70,7 +70,7 @@ class ErrorResponse(BaseModel):
"""Standardized error response format.""" """Standardized error response format."""
success: bool = Field(default=False, description="Always false for error responses") success: bool = Field(default=False, description="Always false for error responses")
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred") errors: list[ErrorDetail] = Field(..., description="List of errors that occurred")
model_config = { model_config = {
"json_schema_extra": { "json_schema_extra": {
@@ -80,9 +80,9 @@ class ErrorResponse(BaseModel):
{ {
"code": "AUTH_001", "code": "AUTH_001",
"message": "Invalid email or password", "message": "Invalid email or password",
"field": None "field": None,
} }
] ],
} }
} }
} }

View File

@@ -1,10 +1,10 @@
# app/schemas/organizations.py # app/schemas/organizations.py
import re import re
from datetime import datetime from datetime import datetime
from typing import Optional, Dict, Any, List from typing import Any
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, field_validator, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, field_validator
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
@@ -12,85 +12,94 @@ from app.models.user_organization import OrganizationRole
# Organization Schemas # Organization Schemas
class OrganizationBase(BaseModel): class OrganizationBase(BaseModel):
"""Base organization schema with common fields.""" """Base organization schema with common fields."""
name: str = Field(..., min_length=1, max_length=255)
slug: Optional[str] = Field(None, min_length=1, max_length=255)
description: Optional[str] = None
is_active: bool = True
settings: Optional[Dict[str, Any]] = {}
@field_validator('slug') name: str = Field(..., min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
is_active: bool = True
settings: dict[str, Any] | None = {}
@field_validator("slug")
@classmethod @classmethod
def validate_slug(cls, v: Optional[str]) -> Optional[str]: def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format: lowercase, alphanumeric, hyphens only.""" """Validate slug format: lowercase, alphanumeric, hyphens only."""
if v is None: if v is None:
return v return v
if not re.match(r'^[a-z0-9-]+$', v): if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens') raise ValueError(
if v.startswith('-') or v.endswith('-'): "Slug must contain only lowercase letters, numbers, and hyphens"
raise ValueError('Slug cannot start or end with a hyphen') )
if '--' in v: if v.startswith("-") or v.endswith("-"):
raise ValueError('Slug cannot contain consecutive hyphens') raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v return v
@field_validator('name') @field_validator("name")
@classmethod @classmethod
def validate_name(cls, v: str) -> str: def validate_name(cls, v: str) -> str:
"""Validate organization name.""" """Validate organization name."""
if not v or v.strip() == "": if not v or v.strip() == "":
raise ValueError('Organization name cannot be empty') raise ValueError("Organization name cannot be empty")
return v.strip() return v.strip()
class OrganizationCreate(OrganizationBase): class OrganizationCreate(OrganizationBase):
"""Schema for creating a new organization.""" """Schema for creating a new organization."""
name: str = Field(..., min_length=1, max_length=255) name: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255) slug: str = Field(..., min_length=1, max_length=255)
class OrganizationUpdate(BaseModel): class OrganizationUpdate(BaseModel):
"""Schema for updating an organization.""" """Schema for updating an organization."""
name: Optional[str] = Field(None, min_length=1, max_length=255)
slug: Optional[str] = Field(None, min_length=1, max_length=255)
description: Optional[str] = None
is_active: Optional[bool] = None
settings: Optional[Dict[str, Any]] = None
@field_validator('slug') name: str | None = Field(None, min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
is_active: bool | None = None
settings: dict[str, Any] | None = None
@field_validator("slug")
@classmethod @classmethod
def validate_slug(cls, v: Optional[str]) -> Optional[str]: def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format.""" """Validate slug format."""
if v is None: if v is None:
return v return v
if not re.match(r'^[a-z0-9-]+$', v): if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens') raise ValueError(
if v.startswith('-') or v.endswith('-'): "Slug must contain only lowercase letters, numbers, and hyphens"
raise ValueError('Slug cannot start or end with a hyphen') )
if '--' in v: if v.startswith("-") or v.endswith("-"):
raise ValueError('Slug cannot contain consecutive hyphens') raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v return v
@field_validator('name') @field_validator("name")
@classmethod @classmethod
def validate_name(cls, v: Optional[str]) -> Optional[str]: def validate_name(cls, v: str | None) -> str | None:
"""Validate organization name.""" """Validate organization name."""
if v is not None and (not v or v.strip() == ""): if v is not None and (not v or v.strip() == ""):
raise ValueError('Organization name cannot be empty') raise ValueError("Organization name cannot be empty")
return v.strip() if v else v return v.strip() if v else v
class OrganizationResponse(OrganizationBase): class OrganizationResponse(OrganizationBase):
"""Schema for organization API responses.""" """Schema for organization API responses."""
id: UUID id: UUID
created_at: datetime created_at: datetime
updated_at: Optional[datetime] = None updated_at: datetime | None = None
member_count: Optional[int] = 0 member_count: int | None = 0
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class OrganizationListResponse(BaseModel): class OrganizationListResponse(BaseModel):
"""Schema for paginated organization list responses.""" """Schema for paginated organization list responses."""
organizations: List[OrganizationResponse]
organizations: list[OrganizationResponse]
total: int total: int
page: int page: int
page_size: int page_size: int
@@ -100,44 +109,49 @@ class OrganizationListResponse(BaseModel):
# User-Organization Relationship Schemas # User-Organization Relationship Schemas
class UserOrganizationBase(BaseModel): class UserOrganizationBase(BaseModel):
"""Base schema for user-organization relationship.""" """Base schema for user-organization relationship."""
role: OrganizationRole = OrganizationRole.MEMBER role: OrganizationRole = OrganizationRole.MEMBER
is_active: bool = True is_active: bool = True
custom_permissions: Optional[str] = None custom_permissions: str | None = None
class UserOrganizationCreate(BaseModel): class UserOrganizationCreate(BaseModel):
"""Schema for adding a user to an organization.""" """Schema for adding a user to an organization."""
user_id: UUID user_id: UUID
role: OrganizationRole = OrganizationRole.MEMBER role: OrganizationRole = OrganizationRole.MEMBER
custom_permissions: Optional[str] = None custom_permissions: str | None = None
class UserOrganizationUpdate(BaseModel): class UserOrganizationUpdate(BaseModel):
"""Schema for updating user's role in an organization.""" """Schema for updating user's role in an organization."""
role: Optional[OrganizationRole] = None
is_active: Optional[bool] = None role: OrganizationRole | None = None
custom_permissions: Optional[str] = None is_active: bool | None = None
custom_permissions: str | None = None
class UserOrganizationResponse(BaseModel): class UserOrganizationResponse(BaseModel):
"""Schema for user-organization relationship responses.""" """Schema for user-organization relationship responses."""
user_id: UUID user_id: UUID
organization_id: UUID organization_id: UUID
role: OrganizationRole role: OrganizationRole
is_active: bool is_active: bool
custom_permissions: Optional[str] = None custom_permissions: str | None = None
created_at: datetime created_at: datetime
updated_at: Optional[datetime] = None updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class OrganizationMemberResponse(BaseModel): class OrganizationMemberResponse(BaseModel):
"""Schema for organization member information.""" """Schema for organization member information."""
user_id: UUID user_id: UUID
email: str email: str
first_name: str first_name: str
last_name: Optional[str] = None last_name: str | None = None
role: OrganizationRole role: OrganizationRole
is_active: bool is_active: bool
joined_at: datetime joined_at: datetime
@@ -147,7 +161,8 @@ class OrganizationMemberResponse(BaseModel):
class OrganizationMemberListResponse(BaseModel): class OrganizationMemberListResponse(BaseModel):
"""Schema for paginated organization member list.""" """Schema for paginated organization member list."""
members: List[OrganizationMemberResponse]
members: list[OrganizationMemberResponse]
total: int total: int
page: int page: int
page_size: int page_size: int

View File

@@ -1,37 +1,44 @@
""" """
Pydantic schemas for user session management. Pydantic schemas for user session management.
""" """
from datetime import datetime from datetime import datetime
from typing import Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field, ConfigDict from pydantic import BaseModel, ConfigDict, Field
class SessionBase(BaseModel): class SessionBase(BaseModel):
"""Base schema for user sessions.""" """Base schema for user sessions."""
device_name: Optional[str] = Field(None, max_length=255, description="Friendly device name")
device_id: Optional[str] = Field(None, max_length=255, description="Persistent device identifier") device_name: str | None = Field(
None, max_length=255, description="Friendly device name"
)
device_id: str | None = Field(
None, max_length=255, description="Persistent device identifier"
)
class SessionCreate(SessionBase): class SessionCreate(SessionBase):
"""Schema for creating a new session (internal use).""" """Schema for creating a new session (internal use)."""
user_id: UUID user_id: UUID
refresh_token_jti: str = Field(..., max_length=255) refresh_token_jti: str = Field(..., max_length=255)
ip_address: Optional[str] = Field(None, max_length=45) ip_address: str | None = Field(None, max_length=45)
user_agent: Optional[str] = Field(None, max_length=500) user_agent: str | None = Field(None, max_length=500)
last_used_at: datetime last_used_at: datetime
expires_at: datetime expires_at: datetime
location_city: Optional[str] = Field(None, max_length=100) location_city: str | None = Field(None, max_length=100)
location_country: Optional[str] = Field(None, max_length=100) location_country: str | None = Field(None, max_length=100)
class SessionUpdate(BaseModel): class SessionUpdate(BaseModel):
"""Schema for updating a session (internal use).""" """Schema for updating a session (internal use)."""
last_used_at: Optional[datetime] = None
is_active: Optional[bool] = None last_used_at: datetime | None = None
refresh_token_jti: Optional[str] = None is_active: bool | None = None
expires_at: Optional[datetime] = None refresh_token_jti: str | None = None
expires_at: datetime | None = None
class SessionResponse(SessionBase): class SessionResponse(SessionBase):
@@ -40,14 +47,17 @@ class SessionResponse(SessionBase):
This is what users see when they list their active sessions. This is what users see when they list their active sessions.
""" """
id: UUID id: UUID
ip_address: Optional[str] = None ip_address: str | None = None
location_city: Optional[str] = None location_city: str | None = None
location_country: Optional[str] = None location_country: str | None = None
last_used_at: datetime last_used_at: datetime
created_at: datetime created_at: datetime
expires_at: datetime expires_at: datetime
is_current: bool = Field(default=False, description="Whether this is the current session") is_current: bool = Field(
default=False, description="Whether this is the current session"
)
model_config = ConfigDict( model_config = ConfigDict(
from_attributes=True, from_attributes=True,
@@ -62,14 +72,15 @@ class SessionResponse(SessionBase):
"last_used_at": "2025-10-31T12:00:00Z", "last_used_at": "2025-10-31T12:00:00Z",
"created_at": "2025-10-30T09:00:00Z", "created_at": "2025-10-30T09:00:00Z",
"expires_at": "2025-11-06T09:00:00Z", "expires_at": "2025-11-06T09:00:00Z",
"is_current": True "is_current": True,
} }
} },
) )
class SessionListResponse(BaseModel): class SessionListResponse(BaseModel):
"""Response containing list of sessions.""" """Response containing list of sessions."""
sessions: list[SessionResponse] sessions: list[SessionResponse]
total: int = Field(..., description="Total number of active sessions") total: int = Field(..., description="Total number of active sessions")
@@ -84,10 +95,10 @@ class SessionListResponse(BaseModel):
"last_used_at": "2025-10-31T12:00:00Z", "last_used_at": "2025-10-31T12:00:00Z",
"created_at": "2025-10-30T09:00:00Z", "created_at": "2025-10-30T09:00:00Z",
"expires_at": "2025-11-06T09:00:00Z", "expires_at": "2025-11-06T09:00:00Z",
"is_current": True "is_current": True,
} }
], ],
"total": 1 "total": 1,
} }
} }
) )
@@ -95,17 +106,14 @@ class SessionListResponse(BaseModel):
class LogoutRequest(BaseModel): class LogoutRequest(BaseModel):
"""Request schema for logout endpoint.""" """Request schema for logout endpoint."""
refresh_token: str = Field( refresh_token: str = Field(
..., ..., description="Refresh token for the session to logout from", min_length=10
description="Refresh token for the session to logout from",
min_length=10
) )
model_config = ConfigDict( model_config = ConfigDict(
json_schema_extra={ json_schema_extra={
"example": { "example": {"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
} }
) )
@@ -116,13 +124,14 @@ class AdminSessionResponse(SessionBase):
Includes user information for admin to see who owns each session. Includes user information for admin to see who owns each session.
""" """
id: UUID id: UUID
user_id: UUID user_id: UUID
user_email: str = Field(..., description="Email of the user who owns this session") user_email: str = Field(..., description="Email of the user who owns this session")
user_full_name: Optional[str] = Field(None, description="Full name of the user") user_full_name: str | None = Field(None, description="Full name of the user")
ip_address: Optional[str] = None ip_address: str | None = None
location_city: Optional[str] = None location_city: str | None = None
location_country: Optional[str] = None location_country: str | None = None
last_used_at: datetime last_used_at: datetime
created_at: datetime created_at: datetime
expires_at: datetime expires_at: datetime
@@ -144,20 +153,21 @@ class AdminSessionResponse(SessionBase):
"last_used_at": "2025-10-31T12:00:00Z", "last_used_at": "2025-10-31T12:00:00Z",
"created_at": "2025-10-30T09:00:00Z", "created_at": "2025-10-30T09:00:00Z",
"expires_at": "2025-11-06T09:00:00Z", "expires_at": "2025-11-06T09:00:00Z",
"is_active": True "is_active": True,
} }
} },
) )
class DeviceInfo(BaseModel): class DeviceInfo(BaseModel):
"""Device information extracted from request.""" """Device information extracted from request."""
device_name: Optional[str] = None
device_id: Optional[str] = None device_name: str | None = None
ip_address: Optional[str] = None device_id: str | None = None
user_agent: Optional[str] = None ip_address: str | None = None
location_city: Optional[str] = None user_agent: str | None = None
location_country: Optional[str] = None location_city: str | None = None
location_country: str | None = None
model_config = ConfigDict( model_config = ConfigDict(
json_schema_extra={ json_schema_extra={
@@ -167,7 +177,7 @@ class DeviceInfo(BaseModel):
"ip_address": "192.168.1.50", "ip_address": "192.168.1.50",
"user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...", "user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...",
"location_city": "San Francisco", "location_city": "San Francisco",
"location_country": "United States" "location_country": "United States",
} }
} }
) )

View File

@@ -1,9 +1,9 @@
# app/schemas/users.py # app/schemas/users.py
from datetime import datetime from datetime import datetime
from typing import Optional, Dict, Any from typing import Any
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict, Field from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from app.schemas.validators import validate_password_strength, validate_phone_number from app.schemas.validators import validate_password_strength, validate_phone_number
@@ -11,12 +11,12 @@ from app.schemas.validators import validate_password_strength, validate_phone_nu
class UserBase(BaseModel): class UserBase(BaseModel):
email: EmailStr email: EmailStr
first_name: str first_name: str
last_name: Optional[str] = None last_name: str | None = None
phone_number: Optional[str] = None phone_number: str | None = None
@field_validator('phone_number') @field_validator("phone_number")
@classmethod @classmethod
def validate_phone(cls, v: Optional[str]) -> Optional[str]: def validate_phone(cls, v: str | None) -> str | None:
return validate_phone_number(v) return validate_phone_number(v)
@@ -24,7 +24,7 @@ class UserCreate(UserBase):
password: str password: str
is_superuser: bool = False is_superuser: bool = False
@field_validator('password') @field_validator("password")
@classmethod @classmethod
def password_strength(cls, v: str) -> str: def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation""" """Enterprise-grade password strength validation"""
@@ -32,30 +32,32 @@ class UserCreate(UserBase):
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
first_name: Optional[str] = None first_name: str | None = None
last_name: Optional[str] = None last_name: str | None = None
phone_number: Optional[str] = None phone_number: str | None = None
password: Optional[str] = None password: str | None = None
preferences: Optional[Dict[str, Any]] = None preferences: dict[str, Any] | None = None
is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates is_active: bool | None = (
is_superuser: Optional[bool] = None # Explicitly reject privilege escalation attempts None # Changed default from True to None to avoid unintended updates
)
is_superuser: bool | None = None # Explicitly reject privilege escalation attempts
@field_validator('phone_number') @field_validator("phone_number")
@classmethod @classmethod
def validate_phone(cls, v: Optional[str]) -> Optional[str]: def validate_phone(cls, v: str | None) -> str | None:
return validate_phone_number(v) return validate_phone_number(v)
@field_validator('password') @field_validator("password")
@classmethod @classmethod
def password_strength(cls, v: Optional[str]) -> Optional[str]: def password_strength(cls, v: str | None) -> str | None:
"""Enterprise-grade password strength validation""" """Enterprise-grade password strength validation"""
if v is None: if v is None:
return v return v
return validate_password_strength(v) return validate_password_strength(v)
@field_validator('is_superuser') @field_validator("is_superuser")
@classmethod @classmethod
def prevent_superuser_modification(cls, v: Optional[bool]) -> Optional[bool]: def prevent_superuser_modification(cls, v: bool | None) -> bool | None:
"""Prevent users from modifying their superuser status via this schema.""" """Prevent users from modifying their superuser status via this schema."""
if v is not None: if v is not None:
raise ValueError("Cannot modify superuser status through user update") raise ValueError("Cannot modify superuser status through user update")
@@ -67,7 +69,7 @@ class UserInDB(UserBase):
is_active: bool is_active: bool
is_superuser: bool is_superuser: bool
created_at: datetime created_at: datetime
updated_at: Optional[datetime] = None updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@@ -77,28 +79,28 @@ class UserResponse(UserBase):
is_active: bool is_active: bool
is_superuser: bool is_superuser: bool
created_at: datetime created_at: datetime
updated_at: Optional[datetime] = None updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
refresh_token: Optional[str] = None refresh_token: str | None = None
token_type: str = "bearer" token_type: str = "bearer"
user: "UserResponse" # Forward reference since UserResponse is defined above user: "UserResponse" # Forward reference since UserResponse is defined above
expires_in: Optional[int] = None # Token expiration in seconds expires_in: int | None = None # Token expiration in seconds
class TokenPayload(BaseModel): class TokenPayload(BaseModel):
sub: str # User ID sub: str # User ID
exp: int # Expiration time exp: int # Expiration time
iat: Optional[int] = None # Issued at iat: int | None = None # Issued at
jti: Optional[str] = None # JWT ID jti: str | None = None # JWT ID
is_superuser: Optional[bool] = False is_superuser: bool | None = False
first_name: Optional[str] = None first_name: str | None = None
email: Optional[str] = None email: str | None = None
type: Optional[str] = None # Token type (access/refresh) type: str | None = None # Token type (access/refresh)
class TokenData(BaseModel): class TokenData(BaseModel):
@@ -108,10 +110,11 @@ class TokenData(BaseModel):
class PasswordChange(BaseModel): class PasswordChange(BaseModel):
"""Schema for changing password (requires current password).""" """Schema for changing password (requires current password)."""
current_password: str current_password: str
new_password: str new_password: str
@field_validator('new_password') @field_validator("new_password")
@classmethod @classmethod
def password_strength(cls, v: str) -> str: def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation""" """Enterprise-grade password strength validation"""
@@ -120,10 +123,11 @@ class PasswordChange(BaseModel):
class PasswordReset(BaseModel): class PasswordReset(BaseModel):
"""Schema for resetting password (via email token).""" """Schema for resetting password (via email token)."""
token: str token: str
new_password: str new_password: str
@field_validator('new_password') @field_validator("new_password")
@classmethod @classmethod
def password_strength(cls, v: str) -> str: def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation""" """Enterprise-grade password strength validation"""
@@ -141,23 +145,19 @@ class RefreshTokenRequest(BaseModel):
class PasswordResetRequest(BaseModel): class PasswordResetRequest(BaseModel):
"""Schema for requesting a password reset.""" """Schema for requesting a password reset."""
email: EmailStr = Field(..., description="Email address of the account") email: EmailStr = Field(..., description="Email address of the account")
model_config = { model_config = {"json_schema_extra": {"example": {"email": "user@example.com"}}}
"json_schema_extra": {
"example": {
"email": "user@example.com"
}
}
}
class PasswordResetConfirm(BaseModel): class PasswordResetConfirm(BaseModel):
"""Schema for confirming a password reset with token.""" """Schema for confirming a password reset with token."""
token: str = Field(..., description="Password reset token from email") token: str = Field(..., description="Password reset token from email")
new_password: str = Field(..., min_length=8, description="New password") new_password: str = Field(..., min_length=8, description="New password")
@field_validator('new_password') @field_validator("new_password")
@classmethod @classmethod
def password_strength(cls, v: str) -> str: def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation""" """Enterprise-grade password strength validation"""
@@ -167,7 +167,7 @@ class PasswordResetConfirm(BaseModel):
"json_schema_extra": { "json_schema_extra": {
"example": { "example": {
"token": "eyJwYXlsb2FkIjp7ImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTcxMjM0NTY3OH19", "token": "eyJwYXlsb2FkIjp7ImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTcxMjM0NTY3OH19",
"new_password": "NewSecurePassword123" "new_password": "NewSecurePassword123",
} }
} }
} }

View File

@@ -4,19 +4,34 @@ Shared validators for Pydantic schemas.
This module provides reusable validation functions to ensure consistency This module provides reusable validation functions to ensure consistency
across all schemas and avoid code duplication. across all schemas and avoid code duplication.
""" """
import re import re
from typing import Set
# Common weak passwords that should be rejected # Common weak passwords that should be rejected
COMMON_PASSWORDS: Set[str] = { COMMON_PASSWORDS: set[str] = {
'password', 'password1', 'password123', 'password1234', "password",
'admin', 'admin123', 'admin1234', "password1",
'welcome', 'welcome1', 'welcome123', "password123",
'qwerty', 'qwerty123', "password1234",
'12345678', '123456789', '1234567890', "admin",
'letmein', 'letmein1', 'letmein123', "admin123",
'monkey123', 'dragon123', "admin1234",
'passw0rd', 'p@ssw0rd', 'p@ssword', "welcome",
"welcome1",
"welcome123",
"qwerty",
"qwerty123",
"12345678",
"123456789",
"1234567890",
"letmein",
"letmein1",
"letmein123",
"monkey123",
"dragon123",
"passw0rd",
"p@ssw0rd",
"p@ssword",
} }
@@ -47,18 +62,21 @@ def validate_password_strength(password: str) -> str:
""" """
# Check minimum length # Check minimum length
if len(password) < 12: if len(password) < 12:
raise ValueError('Password must be at least 12 characters long') raise ValueError("Password must be at least 12 characters long")
# Check against common passwords (case-insensitive) # Check against common passwords (case-insensitive)
if password.lower() in COMMON_PASSWORDS: if password.lower() in COMMON_PASSWORDS:
raise ValueError('Password is too common. Please choose a stronger password') raise ValueError("Password is too common. Please choose a stronger password")
# Check for required character types # Check for required character types
checks = [ checks = [
(any(c.islower() for c in password), 'at least one lowercase letter'), (any(c.islower() for c in password), "at least one lowercase letter"),
(any(c.isupper() for c in password), 'at least one uppercase letter'), (any(c.isupper() for c in password), "at least one uppercase letter"),
(any(c.isdigit() for c in password), 'at least one digit'), (any(c.isdigit() for c in password), "at least one digit"),
(any(c in '!@#$%^&*()_+-=[]{}|;:,.<>?~`' for c in password), 'at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)') (
any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?~`" for c in password),
"at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)",
),
] ]
failed = [msg for check, msg in checks if not check] failed = [msg for check, msg in checks if not check]
@@ -94,10 +112,10 @@ def validate_phone_number(phone: str | None) -> str | None:
# Check for empty strings # Check for empty strings
if not phone or phone.strip() == "": if not phone or phone.strip() == "":
raise ValueError('Phone number cannot be empty') raise ValueError("Phone number cannot be empty")
# Remove all spaces and formatting characters # Remove all spaces and formatting characters
cleaned = re.sub(r'[\s\-\(\)]', '', phone) cleaned = re.sub(r"[\s\-\(\)]", "", phone)
# Basic pattern: # Basic pattern:
# Must start with + or 0 # Must start with + or 0
@@ -105,19 +123,19 @@ def validate_phone_number(phone: str | None) -> str | None:
# After 0 must have at least 8 digits # After 0 must have at least 8 digits
# Maximum total length of 15 digits (international standard) # Maximum total length of 15 digits (international standard)
# Only allowed characters are + at start and digits # Only allowed characters are + at start and digits
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$' pattern = r"^(?:\+[0-9]{8,14}|0[0-9]{8,14})$"
if not re.match(pattern, cleaned): if not re.match(pattern, cleaned):
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits') raise ValueError("Phone number must start with + or 0 followed by 8-14 digits")
# Additional validation to catch specific invalid cases # Additional validation to catch specific invalid cases
# NOTE: These checks are defensive code - the regex pattern above already catches these cases # NOTE: These checks are defensive code - the regex pattern above already catches these cases
if cleaned.count('+') > 1: # pragma: no cover if cleaned.count("+") > 1: # pragma: no cover
raise ValueError('Phone number can only contain one + symbol at the start') raise ValueError("Phone number can only contain one + symbol at the start")
# Check for any non-digit characters (except the leading +) # Check for any non-digit characters (except the leading +)
if not all(c.isdigit() for c in cleaned[1:]): # pragma: no cover if not all(c.isdigit() for c in cleaned[1:]): # pragma: no cover
raise ValueError('Phone number can only contain digits after the prefix') raise ValueError("Phone number can only contain digits after the prefix")
return cleaned return cleaned
@@ -169,16 +187,16 @@ def validate_slug(slug: str) -> str:
ValueError: If slug format is invalid ValueError: If slug format is invalid
""" """
if not slug or len(slug) < 2: if not slug or len(slug) < 2:
raise ValueError('Slug must be at least 2 characters long') raise ValueError("Slug must be at least 2 characters long")
if len(slug) > 50: if len(slug) > 50:
raise ValueError('Slug must be at most 50 characters long') raise ValueError("Slug must be at most 50 characters long")
# Check format # Check format
if not re.match(r'^[a-z0-9]+(?:-[a-z0-9]+)*$', slug): if not re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", slug):
raise ValueError( raise ValueError(
'Slug can only contain lowercase letters, numbers, and hyphens. ' "Slug can only contain lowercase letters, numbers, and hyphens. "
'It cannot start or end with a hyphen, and cannot contain consecutive hyphens' "It cannot start or end with a hyphen, and cannot contain consecutive hyphens"
) )
return slug return slug

View File

@@ -1,18 +1,17 @@
# app/services/auth_service.py # app/services/auth_service.py
import logging import logging
from typing import Optional
from uuid import UUID from uuid import UUID
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import ( from app.core.auth import (
verify_password_async, TokenExpiredError,
get_password_hash_async, TokenInvalidError,
create_access_token, create_access_token,
create_refresh_token, create_refresh_token,
TokenExpiredError, get_password_hash_async,
TokenInvalidError verify_password_async,
) )
from app.core.config import settings from app.core.config import settings
from app.core.exceptions import AuthenticationError from app.core.exceptions import AuthenticationError
@@ -26,7 +25,9 @@ class AuthService:
"""Service for handling authentication operations""" """Service for handling authentication operations"""
@staticmethod @staticmethod
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]: async def authenticate_user(
db: AsyncSession, email: str, password: str
) -> User | None:
""" """
Authenticate a user with email and password using async password verification. Authenticate a user with email and password using async password verification.
@@ -87,7 +88,7 @@ class AuthService:
last_name=user_data.last_name, last_name=user_data.last_name,
phone_number=user_data.phone_number, phone_number=user_data.phone_number,
is_active=True, is_active=True,
is_superuser=False is_superuser=False,
) )
db.add(user) db.add(user)
@@ -103,8 +104,8 @@ class AuthService:
except Exception as e: except Exception as e:
# Rollback on any database errors # Rollback on any database errors
await db.rollback() await db.rollback()
logger.error(f"Error creating user: {str(e)}", exc_info=True) logger.error(f"Error creating user: {e!s}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {str(e)}") raise AuthenticationError(f"Failed to create user: {e!s}")
@staticmethod @staticmethod
def create_tokens(user: User) -> Token: def create_tokens(user: User) -> Token:
@@ -121,18 +122,13 @@ class AuthService:
claims = { claims = {
"is_superuser": user.is_superuser, "is_superuser": user.is_superuser,
"email": user.email, "email": user.email,
"first_name": user.first_name "first_name": user.first_name,
} }
# Create tokens # Create tokens
access_token = create_access_token( access_token = create_access_token(subject=str(user.id), claims=claims)
subject=str(user.id),
claims=claims
)
refresh_token = create_refresh_token( refresh_token = create_refresh_token(subject=str(user.id))
subject=str(user.id)
)
# Convert User model to UserResponse schema # Convert User model to UserResponse schema
user_response = UserResponse.model_validate(user) user_response = UserResponse.model_validate(user)
@@ -141,7 +137,8 @@ class AuthService:
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, refresh_token=refresh_token,
user=user_response, user=user_response,
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # Convert minutes to seconds expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES
* 60, # Convert minutes to seconds
) )
@staticmethod @staticmethod
@@ -180,11 +177,13 @@ class AuthService:
return AuthService.create_tokens(user) return AuthService.create_tokens(user)
except (TokenExpiredError, TokenInvalidError) as e: except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {str(e)}") logger.warning(f"Token refresh failed: {e!s}")
raise raise
@staticmethod @staticmethod
async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool: async def change_password(
db: AsyncSession, user_id: UUID, current_password: str, new_password: str
) -> bool:
""" """
Change a user's password. Change a user's password.
@@ -223,5 +222,7 @@ class AuthService:
except Exception as e: except Exception as e:
# Rollback on any database errors # Rollback on any database errors
await db.rollback() await db.rollback()
logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True) logger.error(
raise AuthenticationError(f"Failed to change password: {str(e)}") f"Error changing password for user {user_id}: {e!s}", exc_info=True
)
raise AuthenticationError(f"Failed to change password: {e!s}")

View File

@@ -5,9 +5,9 @@ Email service with placeholder implementation.
This service provides email sending functionality with a simple console/log-based This service provides email sending functionality with a simple console/log-based
placeholder that can be easily replaced with a real email provider (SendGrid, SES, etc.) placeholder that can be easily replaced with a real email provider (SendGrid, SES, etc.)
""" """
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional
from app.core.config import settings from app.core.config import settings
@@ -20,13 +20,12 @@ class EmailBackend(ABC):
@abstractmethod @abstractmethod
async def send_email( async def send_email(
self, self,
to: List[str], to: list[str],
subject: str, subject: str,
html_content: str, html_content: str,
text_content: Optional[str] = None text_content: str | None = None,
) -> bool: ) -> bool:
"""Send an email.""" """Send an email."""
pass
class ConsoleEmailBackend(EmailBackend): class ConsoleEmailBackend(EmailBackend):
@@ -39,10 +38,10 @@ class ConsoleEmailBackend(EmailBackend):
async def send_email( async def send_email(
self, self,
to: List[str], to: list[str],
subject: str, subject: str,
html_content: str, html_content: str,
text_content: Optional[str] = None text_content: str | None = None,
) -> bool: ) -> bool:
""" """
Log email content to console/logs. Log email content to console/logs.
@@ -88,10 +87,10 @@ class SMTPEmailBackend(EmailBackend):
async def send_email( async def send_email(
self, self,
to: List[str], to: list[str],
subject: str, subject: str,
html_content: str, html_content: str,
text_content: Optional[str] = None text_content: str | None = None,
) -> bool: ) -> bool:
"""Send email via SMTP.""" """Send email via SMTP."""
# TODO: Implement SMTP sending # TODO: Implement SMTP sending
@@ -108,7 +107,7 @@ class EmailService:
and can be configured to use different backends (console, SMTP, SendGrid, etc.) and can be configured to use different backends (console, SMTP, SendGrid, etc.)
""" """
def __init__(self, backend: Optional[EmailBackend] = None): def __init__(self, backend: EmailBackend | None = None):
""" """
Initialize email service with a backend. Initialize email service with a backend.
@@ -118,10 +117,7 @@ class EmailService:
self.backend = backend or ConsoleEmailBackend() self.backend = backend or ConsoleEmailBackend()
async def send_password_reset_email( async def send_password_reset_email(
self, self, to_email: str, reset_token: str, user_name: str | None = None
to_email: str,
reset_token: str,
user_name: Optional[str] = None
) -> bool: ) -> bool:
""" """
Send password reset email. Send password reset email.
@@ -142,7 +138,7 @@ class EmailService:
# Plain text version # Plain text version
text_content = f""" text_content = f"""
Hello{' ' + user_name if user_name else ''}, Hello{" " + user_name if user_name else ""},
You requested a password reset for your account. Click the link below to reset your password: You requested a password reset for your account. Click the link below to reset your password:
@@ -177,7 +173,7 @@ The {settings.PROJECT_NAME} Team
<h1>Password Reset</h1> <h1>Password Reset</h1>
</div> </div>
<div class="content"> <div class="content">
<p>Hello{' ' + user_name if user_name else ''},</p> <p>Hello{" " + user_name if user_name else ""},</p>
<p>You requested a password reset for your account. Click the button below to reset your password:</p> <p>You requested a password reset for your account. Click the button below to reset your password:</p>
<p style="text-align: center;"> <p style="text-align: center;">
<a href="{reset_url}" class="button">Reset Password</a> <a href="{reset_url}" class="button">Reset Password</a>
@@ -200,17 +196,14 @@ The {settings.PROJECT_NAME} Team
to=[to_email], to=[to_email],
subject=subject, subject=subject,
html_content=html_content, html_content=html_content,
text_content=text_content text_content=text_content,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to send password reset email to {to_email}: {str(e)}") logger.error(f"Failed to send password reset email to {to_email}: {e!s}")
return False return False
async def send_email_verification( async def send_email_verification(
self, self, to_email: str, verification_token: str, user_name: str | None = None
to_email: str,
verification_token: str,
user_name: Optional[str] = None
) -> bool: ) -> bool:
""" """
Send email verification email. Send email verification email.
@@ -224,14 +217,16 @@ The {settings.PROJECT_NAME} Team
True if email sent successfully True if email sent successfully
""" """
# Generate verification URL # Generate verification URL
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}" verification_url = (
f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
)
# Prepare email content # Prepare email content
subject = "Verify Your Email Address" subject = "Verify Your Email Address"
# Plain text version # Plain text version
text_content = f""" text_content = f"""
Hello{' ' + user_name if user_name else ''}, Hello{" " + user_name if user_name else ""},
Thank you for signing up! Please verify your email address by clicking the link below: Thank you for signing up! Please verify your email address by clicking the link below:
@@ -266,7 +261,7 @@ The {settings.PROJECT_NAME} Team
<h1>Verify Your Email</h1> <h1>Verify Your Email</h1>
</div> </div>
<div class="content"> <div class="content">
<p>Hello{' ' + user_name if user_name else ''},</p> <p>Hello{" " + user_name if user_name else ""},</p>
<p>Thank you for signing up! Please verify your email address by clicking the button below:</p> <p>Thank you for signing up! Please verify your email address by clicking the button below:</p>
<p style="text-align: center;"> <p style="text-align: center;">
<a href="{verification_url}" class="button">Verify Email</a> <a href="{verification_url}" class="button">Verify Email</a>
@@ -289,10 +284,10 @@ The {settings.PROJECT_NAME} Team
to=[to_email], to=[to_email],
subject=subject, subject=subject,
html_content=html_content, html_content=html_content,
text_content=text_content text_content=text_content,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to send verification email to {to_email}: {str(e)}") logger.error(f"Failed to send verification email to {to_email}: {e!s}")
return False return False

View File

@@ -3,8 +3,9 @@ Background job for cleaning up expired sessions.
This service runs periodically to remove old session records from the database. This service runs periodically to remove old session records from the database.
""" """
import logging import logging
from datetime import datetime, timezone from datetime import UTC, datetime
from app.core.database import SessionLocal from app.core.database import SessionLocal
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
@@ -39,7 +40,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
return count return count
except Exception as e: except Exception as e:
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True) logger.error(f"Error during session cleanup: {e!s}", exc_info=True)
return 0 return 0
@@ -52,20 +53,21 @@ async def get_session_statistics() -> dict:
""" """
async with SessionLocal() as db: async with SessionLocal() as db:
try: try:
from sqlalchemy import func, select
from app.models.user_session import UserSession from app.models.user_session import UserSession
from sqlalchemy import select, func
total_result = await db.execute(select(func.count(UserSession.id))) total_result = await db.execute(select(func.count(UserSession.id)))
total_sessions = total_result.scalar_one() total_sessions = total_result.scalar_one()
active_result = await db.execute( active_result = await db.execute(
select(func.count(UserSession.id)).where(UserSession.is_active == True) select(func.count(UserSession.id)).where(UserSession.is_active)
) )
active_sessions = active_result.scalar_one() active_sessions = active_result.scalar_one()
expired_result = await db.execute( expired_result = await db.execute(
select(func.count(UserSession.id)).where( select(func.count(UserSession.id)).where(
UserSession.expires_at < datetime.now(timezone.utc) UserSession.expires_at < datetime.now(UTC)
) )
) )
expired_sessions = expired_result.scalar_one() expired_sessions = expired_result.scalar_one()
@@ -82,5 +84,5 @@ async def get_session_statistics() -> dict:
return stats return stats
except Exception as e: except Exception as e:
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True) logger.error(f"Error getting session statistics: {e!s}", exc_info=True)
return {} return {}

View File

@@ -2,7 +2,8 @@
Authentication utilities for testing. Authentication utilities for testing.
This module provides tools to bypass FastAPI's authentication in tests. This module provides tools to bypass FastAPI's authentication in tests.
""" """
from typing import Callable, Dict, Optional
from collections.abc import Callable
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
@@ -13,9 +14,9 @@ from app.models.user import User
def create_test_auth_client( def create_test_auth_client(
app: FastAPI, app: FastAPI,
test_user: User, test_user: User,
extra_overrides: Optional[Dict[Callable, Callable]] = None extra_overrides: dict[Callable, Callable] | None = None,
) -> TestClient: ) -> TestClient:
""" """
Create a test client with authentication pre-configured. Create a test client with authentication pre-configured.
@@ -47,10 +48,7 @@ def create_test_auth_client(
return TestClient(app) return TestClient(app)
def create_test_optional_auth_client( def create_test_optional_auth_client(app: FastAPI, test_user: User) -> TestClient:
app: FastAPI,
test_user: User
) -> TestClient:
""" """
Create a test client with optional authentication pre-configured. Create a test client with optional authentication pre-configured.
@@ -70,10 +68,7 @@ def create_test_optional_auth_client(
return TestClient(app) return TestClient(app)
def create_test_superuser_client( def create_test_superuser_client(app: FastAPI, test_user: User) -> TestClient:
app: FastAPI,
test_user: User
) -> TestClient:
""" """
Create a test client with superuser authentication pre-configured. Create a test client with superuser authentication pre-configured.
@@ -120,7 +115,7 @@ def cleanup_test_client_auth(app: FastAPI) -> None:
auth_deps = [ auth_deps = [
get_current_user, get_current_user,
get_optional_current_user, get_optional_current_user,
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login"),
] ]
# Remove overrides # Remove overrides

View File

@@ -1,8 +1,8 @@
""" """
Utility functions for extracting and parsing device information from HTTP requests. Utility functions for extracting and parsing device information from HTTP requests.
""" """
import re import re
from typing import Optional
from fastapi import Request from fastapi import Request
@@ -19,11 +19,11 @@ def extract_device_info(request: Request) -> DeviceInfo:
Returns: Returns:
DeviceInfo object with parsed device information DeviceInfo object with parsed device information
""" """
user_agent = request.headers.get('user-agent', '') user_agent = request.headers.get("user-agent", "")
device_info = DeviceInfo( device_info = DeviceInfo(
device_name=parse_device_name(user_agent), device_name=parse_device_name(user_agent),
device_id=request.headers.get('x-device-id'), # Client must send this header device_id=request.headers.get("x-device-id"), # Client must send this header
ip_address=get_client_ip(request), ip_address=get_client_ip(request),
user_agent=user_agent[:500] if user_agent else None, # Truncate to max length user_agent=user_agent[:500] if user_agent else None, # Truncate to max length
location_city=None, # Can be populated via IP geolocation service location_city=None, # Can be populated via IP geolocation service
@@ -33,7 +33,7 @@ def extract_device_info(request: Request) -> DeviceInfo:
return device_info return device_info
def parse_device_name(user_agent: str) -> Optional[str]: def parse_device_name(user_agent: str) -> str | None:
""" """
Parse user agent string to extract a friendly device name. Parse user agent string to extract a friendly device name.
@@ -54,48 +54,48 @@ def parse_device_name(user_agent: str) -> Optional[str]:
user_agent_lower = user_agent.lower() user_agent_lower = user_agent.lower()
# Mobile devices (check first, as they can contain desktop patterns too) # Mobile devices (check first, as they can contain desktop patterns too)
if 'iphone' in user_agent_lower: if "iphone" in user_agent_lower:
return "iPhone" return "iPhone"
elif 'ipad' in user_agent_lower: elif "ipad" in user_agent_lower:
return "iPad" return "iPad"
elif 'android' in user_agent_lower: elif "android" in user_agent_lower:
# Try to extract device model # Try to extract device model
android_match = re.search(r'android.*;\s*([^)]+)\s*build', user_agent_lower) android_match = re.search(r"android.*;\s*([^)]+)\s*build", user_agent_lower)
if android_match: if android_match:
device_model = android_match.group(1).strip() device_model = android_match.group(1).strip()
return f"Android ({device_model.title()})" return f"Android ({device_model.title()})"
return "Android device" return "Android device"
elif 'windows phone' in user_agent_lower: elif "windows phone" in user_agent_lower:
return "Windows Phone" return "Windows Phone"
# Tablets (check before desktop, as some tablets contain "android") # Tablets (check before desktop, as some tablets contain "android")
elif 'tablet' in user_agent_lower: elif "tablet" in user_agent_lower:
return "Tablet" return "Tablet"
# Smart TVs (check before desktop OS patterns) # Smart TVs (check before desktop OS patterns)
elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv']): elif any(tv in user_agent_lower for tv in ["smart-tv", "smarttv"]):
return "Smart TV" return "Smart TV"
# Game consoles (check before desktop OS patterns, as Xbox contains "Windows") # Game consoles (check before desktop OS patterns, as Xbox contains "Windows")
elif 'playstation' in user_agent_lower: elif "playstation" in user_agent_lower:
return "PlayStation" return "PlayStation"
elif 'xbox' in user_agent_lower: elif "xbox" in user_agent_lower:
return "Xbox" return "Xbox"
elif 'nintendo' in user_agent_lower: elif "nintendo" in user_agent_lower:
return "Nintendo" return "Nintendo"
# Desktop operating systems # Desktop operating systems
elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower: elif "macintosh" in user_agent_lower or "mac os x" in user_agent_lower:
# Try to extract browser # Try to extract browser
browser = extract_browser(user_agent) browser = extract_browser(user_agent)
return f"{browser} on Mac" if browser else "Mac" return f"{browser} on Mac" if browser else "Mac"
elif 'windows' in user_agent_lower: elif "windows" in user_agent_lower:
browser = extract_browser(user_agent) browser = extract_browser(user_agent)
return f"{browser} on Windows" if browser else "Windows PC" return f"{browser} on Windows" if browser else "Windows PC"
elif 'linux' in user_agent_lower and 'android' not in user_agent_lower: elif "linux" in user_agent_lower and "android" not in user_agent_lower:
browser = extract_browser(user_agent) browser = extract_browser(user_agent)
return f"{browser} on Linux" if browser else "Linux" return f"{browser} on Linux" if browser else "Linux"
elif 'cros' in user_agent_lower: elif "cros" in user_agent_lower:
return "Chromebook" return "Chromebook"
# Fallback: just return browser name if detected # Fallback: just return browser name if detected
@@ -106,7 +106,7 @@ def parse_device_name(user_agent: str) -> Optional[str]:
return "Unknown device" return "Unknown device"
def extract_browser(user_agent: str) -> Optional[str]: def extract_browser(user_agent: str) -> str | None:
""" """
Extract browser name from user agent string. Extract browser name from user agent string.
@@ -126,26 +126,26 @@ def extract_browser(user_agent: str) -> Optional[str]:
user_agent_lower = user_agent.lower() user_agent_lower = user_agent.lower()
# Check specific browsers (order matters - check Edge before Chrome!) # Check specific browsers (order matters - check Edge before Chrome!)
if 'edg/' in user_agent_lower or 'edge/' in user_agent_lower: if "edg/" in user_agent_lower or "edge/" in user_agent_lower:
return "Edge" return "Edge"
elif 'opr/' in user_agent_lower or 'opera' in user_agent_lower: elif "opr/" in user_agent_lower or "opera" in user_agent_lower:
return "Opera" return "Opera"
elif 'chrome/' in user_agent_lower: elif "chrome/" in user_agent_lower:
return "Chrome" return "Chrome"
elif 'safari/' in user_agent_lower: elif "safari/" in user_agent_lower:
# Make sure it's actually Safari, not Chrome (which also contains "Safari") # Make sure it's actually Safari, not Chrome (which also contains "Safari")
if 'chrome' not in user_agent_lower: if "chrome" not in user_agent_lower:
return "Safari" return "Safari"
return None return None
elif 'firefox/' in user_agent_lower: elif "firefox/" in user_agent_lower:
return "Firefox" return "Firefox"
elif 'msie' in user_agent_lower or 'trident/' in user_agent_lower: elif "msie" in user_agent_lower or "trident/" in user_agent_lower:
return "Internet Explorer" return "Internet Explorer"
return None return None
def get_client_ip(request: Request) -> Optional[str]: def get_client_ip(request: Request) -> str | None:
""" """
Extract client IP address from request, considering proxy headers. Extract client IP address from request, considering proxy headers.
@@ -163,14 +163,14 @@ def get_client_ip(request: Request) -> Optional[str]:
- request.client.host is fallback for direct connections - request.client.host is fallback for direct connections
""" """
# Check X-Forwarded-For (common in proxied environments) # Check X-Forwarded-For (common in proxied environments)
x_forwarded_for = request.headers.get('x-forwarded-for') x_forwarded_for = request.headers.get("x-forwarded-for")
if x_forwarded_for: if x_forwarded_for:
# Get the first IP (original client) # Get the first IP (original client)
client_ip = x_forwarded_for.split(',')[0].strip() client_ip = x_forwarded_for.split(",")[0].strip()
return client_ip return client_ip
# Check X-Real-IP (used by some proxies like nginx) # Check X-Real-IP (used by some proxies like nginx)
x_real_ip = request.headers.get('x-real-ip') x_real_ip = request.headers.get("x-real-ip")
if x_real_ip: if x_real_ip:
return x_real_ip.strip() return x_real_ip.strip()
@@ -195,9 +195,17 @@ def is_mobile_device(user_agent: str) -> bool:
return False return False
mobile_patterns = [ mobile_patterns = [
'mobile', 'android', 'iphone', 'ipad', 'ipod', "mobile",
'blackberry', 'windows phone', 'webos', 'opera mini', "android",
'iemobile', 'mobile safari' "iphone",
"ipad",
"ipod",
"blackberry",
"windows phone",
"webos",
"opera mini",
"iemobile",
"mobile safari",
] ]
user_agent_lower = user_agent.lower() user_agent_lower = user_agent.lower()
@@ -220,7 +228,7 @@ def get_device_type(user_agent: str) -> str:
user_agent_lower = user_agent.lower() user_agent_lower = user_agent.lower()
# Check for tablets first (they can contain "mobile" too) # Check for tablets first (they can contain "mobile" too)
if 'ipad' in user_agent_lower or 'tablet' in user_agent_lower: if "ipad" in user_agent_lower or "tablet" in user_agent_lower:
return "tablet" return "tablet"
# Check for mobile # Check for mobile
@@ -228,7 +236,7 @@ def get_device_type(user_agent: str) -> str:
return "mobile" return "mobile"
# Check for desktop OS patterns # Check for desktop OS patterns
if any(os in user_agent_lower for os in ['windows', 'macintosh', 'linux', 'cros']): if any(os in user_agent_lower for os in ["windows", "macintosh", "linux", "cros"]):
return "desktop" return "desktop"
return "other" return "other"

View File

@@ -5,18 +5,21 @@ This module provides utilities for creating and verifying signed tokens,
useful for operations like file uploads, password resets, or any other useful for operations like file uploads, password resets, or any other
time-limited, single-use operations. time-limited, single-use operations.
""" """
import base64 import base64
import hashlib import hashlib
import hmac import hmac
import json import json
import secrets import secrets
import time import time
from typing import Dict, Any, Optional from typing import Any
from app.core.config import settings from app.core.config import settings
def create_upload_token(file_path: str, content_type: str, expires_in: int = 300) -> str: def create_upload_token(
file_path: str, content_type: str, expires_in: int = 300
) -> str:
""" """
Create a signed token for secure file uploads. Create a signed token for secure file uploads.
@@ -40,34 +43,29 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300
"path": file_path, "path": file_path,
"content_type": content_type, "content_type": content_type,
"exp": int(time.time()) + expires_in, "exp": int(time.time()) + expires_in,
"nonce": secrets.token_hex(8) # Add randomness to prevent token reuse "nonce": secrets.token_hex(8), # Add randomness to prevent token reuse
} }
# Convert to JSON and encode # Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8') payload_bytes = json.dumps(payload).encode("utf-8")
# Create a signature using HMAC-SHA256 for security # Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to # This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new( signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'), settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
payload_bytes,
hashlib.sha256
).hexdigest() ).hexdigest()
# Combine payload and signature # Combine payload and signature
token_data = { token_data = {"payload": payload, "signature": signature}
"payload": payload,
"signature": signature
}
# Encode the final token # Encode the final token
token_json = json.dumps(token_data) token_json = json.dumps(token_data)
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8') token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
return token return token
def verify_upload_token(token: str) -> Optional[Dict[str, Any]]: def verify_upload_token(token: str) -> dict[str, Any] | None:
""" """
Verify an upload token and return the payload if valid. Verify an upload token and return the payload if valid.
@@ -88,7 +86,7 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
""" """
try: try:
# Decode the token # Decode the token
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(token_json) token_data = json.loads(token_json)
# Extract payload and signature # Extract payload and signature
@@ -96,11 +94,9 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
signature = token_data["signature"] signature = token_data["signature"]
# Verify signature using HMAC and constant-time comparison # Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8') payload_bytes = json.dumps(payload).encode("utf-8")
expected_signature = hmac.new( expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'), settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
payload_bytes,
hashlib.sha256
).hexdigest() ).hexdigest()
if not hmac.compare_digest(signature, expected_signature): if not hmac.compare_digest(signature, expected_signature):
@@ -136,34 +132,29 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str:
"email": email, "email": email,
"exp": int(time.time()) + expires_in, "exp": int(time.time()) + expires_in,
"nonce": secrets.token_hex(16), # Extra randomness "nonce": secrets.token_hex(16), # Extra randomness
"purpose": "password_reset" "purpose": "password_reset",
} }
# Convert to JSON and encode # Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8') payload_bytes = json.dumps(payload).encode("utf-8")
# Create a signature using HMAC-SHA256 for security # Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to # This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new( signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'), settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
payload_bytes,
hashlib.sha256
).hexdigest() ).hexdigest()
# Combine payload and signature # Combine payload and signature
token_data = { token_data = {"payload": payload, "signature": signature}
"payload": payload,
"signature": signature
}
# Encode the final token # Encode the final token
token_json = json.dumps(token_data) token_json = json.dumps(token_data)
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8') token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
return token return token
def verify_password_reset_token(token: str) -> Optional[str]: def verify_password_reset_token(token: str) -> str | None:
""" """
Verify a password reset token and return the email if valid. Verify a password reset token and return the email if valid.
@@ -182,7 +173,7 @@ def verify_password_reset_token(token: str) -> Optional[str]:
""" """
try: try:
# Decode the token # Decode the token
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(token_json) token_data = json.loads(token_json)
# Extract payload and signature # Extract payload and signature
@@ -194,11 +185,9 @@ def verify_password_reset_token(token: str) -> Optional[str]:
return None return None
# Verify signature using HMAC and constant-time comparison # Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8') payload_bytes = json.dumps(payload).encode("utf-8")
expected_signature = hmac.new( expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'), settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
payload_bytes,
hashlib.sha256
).hexdigest() ).hexdigest()
if not hmac.compare_digest(signature, expected_signature): if not hmac.compare_digest(signature, expected_signature):
@@ -234,34 +223,29 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str:
"email": email, "email": email,
"exp": int(time.time()) + expires_in, "exp": int(time.time()) + expires_in,
"nonce": secrets.token_hex(16), "nonce": secrets.token_hex(16),
"purpose": "email_verification" "purpose": "email_verification",
} }
# Convert to JSON and encode # Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8') payload_bytes = json.dumps(payload).encode("utf-8")
# Create a signature using HMAC-SHA256 for security # Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to # This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new( signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'), settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
payload_bytes,
hashlib.sha256
).hexdigest() ).hexdigest()
# Combine payload and signature # Combine payload and signature
token_data = { token_data = {"payload": payload, "signature": signature}
"payload": payload,
"signature": signature
}
# Encode the final token # Encode the final token
token_json = json.dumps(token_data) token_json = json.dumps(token_data)
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8') token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
return token return token
def verify_email_verification_token(token: str) -> Optional[str]: def verify_email_verification_token(token: str) -> str | None:
""" """
Verify an email verification token and return the email if valid. Verify an email verification token and return the email if valid.
@@ -280,7 +264,7 @@ def verify_email_verification_token(token: str) -> Optional[str]:
""" """
try: try:
# Decode the token # Decode the token
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(token_json) token_data = json.loads(token_json)
# Extract payload and signature # Extract payload and signature
@@ -292,11 +276,9 @@ def verify_email_verification_token(token: str) -> Optional[str]:
return None return None
# Verify signature using HMAC and constant-time comparison # Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8') payload_bytes = json.dumps(payload).encode("utf-8")
expected_signature = hmac.new( expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'), settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
payload_bytes,
hashlib.sha256
).hexdigest() ).hexdigest()
if not hmac.compare_digest(signature, expected_signature): if not hmac.compare_digest(signature, expected_signature):

View File

@@ -9,17 +9,19 @@ from app.core.database import Base
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_test_engine(): def get_test_engine():
"""Create an SQLite in-memory engine specifically for testing""" """Create an SQLite in-memory engine specifically for testing"""
test_engine = create_engine( test_engine = create_engine(
"sqlite:///:memory:", "sqlite:///:memory:",
connect_args={"check_same_thread": False}, connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing poolclass=StaticPool, # Use static pool for in-memory testing
echo=False echo=False,
) )
return test_engine return test_engine
def setup_test_db(): def setup_test_db():
"""Create a test database and session factory""" """Create a test database and session factory"""
# Create a new engine for this test run # Create a new engine for this test run
@@ -30,14 +32,12 @@ def setup_test_db():
# Create session factory # Create session factory
TestingSessionLocal = sessionmaker( TestingSessionLocal = sessionmaker(
autocommit=False, autocommit=False, autoflush=False, bind=test_engine, expire_on_commit=False
autoflush=False,
bind=test_engine,
expire_on_commit=False
) )
return test_engine, TestingSessionLocal return test_engine, TestingSessionLocal
def teardown_test_db(engine): def teardown_test_db(engine):
"""Clean up after tests""" """Clean up after tests"""
# Drop all tables # Drop all tables
@@ -46,13 +46,14 @@ def teardown_test_db(engine):
# Dispose of engine # Dispose of engine
engine.dispose() engine.dispose()
async def get_async_test_engine(): async def get_async_test_engine():
"""Create an async SQLite in-memory engine specifically for testing""" """Create an async SQLite in-memory engine specifically for testing"""
test_engine = create_async_engine( test_engine = create_async_engine(
"sqlite+aiosqlite:///:memory:", "sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False}, connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing poolclass=StaticPool, # Use static pool for in-memory testing
echo=False echo=False,
) )
return test_engine return test_engine
@@ -69,7 +70,7 @@ async def setup_async_test_db():
autoflush=False, autoflush=False,
bind=test_engine, bind=test_engine,
expire_on_commit=False, expire_on_commit=False,
class_=AsyncSession class_=AsyncSession,
) )
return test_engine, AsyncTestingSessionLocal return test_engine, AsyncTestingSessionLocal

View File

@@ -1,15 +1,16 @@
# tests/api/dependencies/test_auth_dependencies.py # tests/api/dependencies/test_auth_dependencies.py
import pytest
import pytest_asyncio
import uuid import uuid
from unittest.mock import patch from unittest.mock import patch
import pytest
import pytest_asyncio
from fastapi import HTTPException from fastapi import HTTPException
from app.api.dependencies.auth import ( from app.api.dependencies.auth import (
get_current_user,
get_current_active_user, get_current_active_user,
get_current_superuser, get_current_superuser,
get_optional_current_user get_current_user,
get_optional_current_user,
) )
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
from app.models.user import User from app.models.user import User
@@ -24,7 +25,7 @@ def mock_token():
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def async_mock_user(async_test_db): async def async_mock_user(async_test_db):
"""Async fixture to create and return a mock User instance.""" """Async fixture to create and return a mock User instance."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
mock_user = User( mock_user = User(
id=uuid.uuid4(), id=uuid.uuid4(),
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
"""Tests for get_current_user dependency""" """Tests for get_current_user dependency"""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token): async def test_get_current_user_success(
self, async_test_db, async_mock_user, mock_token
):
"""Test successfully getting the current user""" """Test successfully getting the current user"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return user_id that matches our mock_user # Mock get_token_data to return user_id that matches our mock_user
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency # Call the dependency
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_nonexistent(self, async_test_db, mock_token): async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
"""Test when the token contains a user ID that doesn't exist""" """Test when the token contains a user ID that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return a non-existent user ID # Mock get_token_data to return a non-existent user ID
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111") nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = nonexistent_id mock_get_data.return_value.user_id = nonexistent_id
# Should raise HTTPException with 404 status # Should raise HTTPException with 404 status
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
assert "User not found" in exc_info.value.detail assert "User not found" in exc_info.value.detail
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token): async def test_get_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test when the user is inactive""" """Test when the user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive # Get the user in this session and make it inactive
from sqlalchemy import select from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none() user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False user_in_session.is_active = False
await session.commit() await session.commit()
# Mock get_token_data # Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id mock_get_data.return_value.user_id = async_mock_user.id
# Should raise HTTPException with 403 status # Should raise HTTPException with 403 status
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_expired_token(self, async_test_db, mock_token): async def test_get_current_user_expired_token(self, async_test_db, mock_token):
"""Test with an expired token""" """Test with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError # Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired") mock_get_data.side_effect = TokenExpiredError("Token expired")
# Should raise HTTPException with 401 status # Should raise HTTPException with 401 status
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_invalid_token(self, async_test_db, mock_token): async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
"""Test with an invalid token""" """Test with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError # Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token") mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Should raise HTTPException with 401 status # Should raise HTTPException with 401 status
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
"""Tests for get_optional_current_user dependency""" """Tests for get_optional_current_user dependency"""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token): async def test_get_optional_current_user_with_token(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user with a valid token""" """Test getting optional user with a valid token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data # Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency # Call the dependency
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_no_token(self, async_test_db): async def test_get_optional_current_user_no_token(self, async_test_db):
"""Test getting optional user with no token""" """Test getting optional user with no token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Call the dependency with no token # Call the dependency with no token
user = await get_optional_current_user(db=session, token=None) user = await get_optional_current_user(db=session, token=None)
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
assert user is None assert user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token): async def test_get_optional_current_user_invalid_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an invalid token""" """Test getting optional user with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError # Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token") mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Call the dependency # Call the dependency
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
assert user is None assert user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token): async def test_get_optional_current_user_expired_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an expired token""" """Test getting optional user with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError # Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired") mock_get_data.side_effect = TokenExpiredError("Token expired")
# Call the dependency # Call the dependency
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
assert user is None assert user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token): async def test_get_optional_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user when user is inactive""" """Test getting optional user when user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive # Get the user in this session and make it inactive
from sqlalchemy import select from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none() user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False user_in_session.is_active = False
await session.commit() await session.commit()
# Mock get_token_data # Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency # Call the dependency

View File

@@ -1,13 +1,12 @@
# tests/api/routes/test_health.py # tests/api/routes/test_health.py
from datetime import datetime
from unittest.mock import patch
import pytest import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi import status from fastapi import status
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from datetime import datetime
from sqlalchemy.exc import OperationalError
from app.main import app from app.main import app
from app.core.database import get_db
@pytest.fixture @pytest.fixture
@@ -121,7 +120,10 @@ class TestHealthEndpoint:
response = client.get("/health") response = client.get("/health")
# Should succeed without authentication # Should succeed without authentication
assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE] assert response.status_code in [
status.HTTP_200_OK,
status.HTTP_503_SERVICE_UNAVAILABLE,
]
def test_health_check_idempotent(self, client): def test_health_check_idempotent(self, client):
"""Test that multiple health checks return consistent results""" """Test that multiple health checks return consistent results"""
@@ -142,7 +144,10 @@ class TestHealthEndpoint:
assert data1["environment"] == data2["environment"] assert data1["environment"] == data2["environment"]
# Same database check status # Same database check status
assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"] assert (
data1["checks"]["database"]["status"]
== data2["checks"]["database"]["status"]
)
def test_health_check_content_type(self, client): def test_health_check_content_type(self, client):
"""Test that health check returns JSON content type""" """Test that health check returns JSON content type"""

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@
""" """
Tests for authentication endpoints. Tests for authentication endpoints.
""" """
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from fastapi import status from fastapi import status
@@ -19,8 +20,8 @@ class TestRegisterEndpoint:
"email": "newuser@example.com", "email": "newuser@example.com",
"password": "NewPassword123!", "password": "NewPassword123!",
"first_name": "New", "first_name": "New",
"last_name": "User" "last_name": "User",
} },
) )
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
@@ -36,8 +37,8 @@ class TestRegisterEndpoint:
"email": async_test_user.email, "email": async_test_user.email,
"password": "TestPassword123!", "password": "TestPassword123!",
"first_name": "Test", "first_name": "Test",
"last_name": "User" "last_name": "User",
} },
) )
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -51,8 +52,8 @@ class TestRegisterEndpoint:
"email": "test@example.com", "email": "test@example.com",
"password": "weak", "password": "weak",
"first_name": "Test", "first_name": "Test",
"last_name": "User" "last_name": "User",
} },
) )
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -66,10 +67,7 @@ class TestLoginEndpoint:
"""Test successful login.""" """Test successful login."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -82,10 +80,7 @@ class TestLoginEndpoint:
"""Test login with invalid password.""" """Test login with invalid password."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "WrongPassword123!"},
"email": "testuser@example.com",
"password": "WrongPassword123!"
}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -95,10 +90,7 @@ class TestLoginEndpoint:
"""Test login with non-existent user.""" """Test login with non-existent user."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "nonexistent@example.com", "password": "TestPassword123!"},
"email": "nonexistent@example.com",
"password": "TestPassword123!"
}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -106,27 +98,25 @@ class TestLoginEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_inactive_user(self, client, async_test_db): async def test_login_inactive_user(self, client, async_test_db):
"""Test login with inactive user.""" """Test login with inactive user."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
from app.models.user import User
from app.core.auth import get_password_hash from app.core.auth import get_password_hash
from app.models.user import User
inactive_user = User( inactive_user = User(
email="inactive@example.com", email="inactive@example.com",
password_hash=get_password_hash("TestPassword123!"), password_hash=get_password_hash("TestPassword123!"),
first_name="Inactive", first_name="Inactive",
last_name="User", last_name="User",
is_active=False is_active=False,
) )
session.add(inactive_user) session.add(inactive_user)
await session.commit() await session.commit()
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "inactive@example.com", "password": "TestPassword123!"},
"email": "inactive@example.com",
"password": "TestPassword123!"
}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -140,10 +130,7 @@ class TestRefreshTokenEndpoint:
"""Get a refresh token for testing.""" """Get a refresh token for testing."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
return response.json()["refresh_token"] return response.json()["refresh_token"]
@@ -151,8 +138,7 @@ class TestRefreshTokenEndpoint:
async def test_refresh_token_success(self, client, refresh_token): async def test_refresh_token_success(self, client, refresh_token):
"""Test successful token refresh.""" """Test successful token refresh."""
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": refresh_token}
json={"refresh_token": refresh_token}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -164,8 +150,7 @@ class TestRefreshTokenEndpoint:
async def test_refresh_token_invalid(self, client): async def test_refresh_token_invalid(self, client):
"""Test refresh with invalid token.""" """Test refresh with invalid token."""
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": "invalid.token.here"}
json={"refresh_token": "invalid.token.here"}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -179,13 +164,13 @@ class TestLogoutEndpoint:
"""Get tokens for testing.""" """Get tokens for testing."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
data = response.json() data = response.json()
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]} return {
"access_token": data["access_token"],
"refresh_token": data["refresh_token"],
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_logout_success(self, client, tokens): async def test_logout_success(self, client, tokens):
@@ -193,7 +178,7 @@ class TestLogoutEndpoint:
response = await client.post( response = await client.post(
"/api/v1/auth/logout", "/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"}, headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]} json={"refresh_token": tokens["refresh_token"]},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -202,8 +187,7 @@ class TestLogoutEndpoint:
async def test_logout_without_auth(self, client): async def test_logout_without_auth(self, client):
"""Test logout without authentication.""" """Test logout without authentication."""
response = await client.post( response = await client.post(
"/api/v1/auth/logout", "/api/v1/auth/logout", json={"refresh_token": "some.token"}
json={"refresh_token": "some.token"}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -215,8 +199,7 @@ class TestPasswordResetRequest:
async def test_password_reset_request_success(self, client, async_test_user): async def test_password_reset_request_success(self, client, async_test_user):
"""Test password reset request with existing user.""" """Test password reset request with existing user."""
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/request", "/api/v1/auth/password-reset/request", json={"email": async_test_user.email}
json={"email": async_test_user.email}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -228,7 +211,7 @@ class TestPasswordResetRequest:
"""Test password reset request with non-existent email.""" """Test password reset request with non-existent email."""
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/request", "/api/v1/auth/password-reset/request",
json={"email": "nonexistent@example.com"} json={"email": "nonexistent@example.com"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -244,10 +227,7 @@ class TestPasswordResetConfirm:
"""Test password reset with invalid token.""" """Test password reset with invalid token."""
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": "invalid.token.here", "new_password": "NewPassword123!"},
"token": "invalid.token.here",
"new_password": "NewPassword123!"
}
) )
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -261,20 +241,20 @@ class TestLogoutAll:
"""Get tokens for testing.""" """Get tokens for testing."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
data = response.json() data = response.json()
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]} return {
"access_token": data["access_token"],
"refresh_token": data["refresh_token"],
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_logout_all_success(self, client, tokens): async def test_logout_all_success(self, client, tokens):
"""Test logout from all devices.""" """Test logout from all devices."""
response = await client.post( response = await client.post(
"/api/v1/auth/logout-all", "/api/v1/auth/logout-all",
headers={"Authorization": f"Bearer {tokens['access_token']}"} headers={"Authorization": f"Bearer {tokens['access_token']}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -298,10 +278,7 @@ class TestOAuthLogin:
"""Test successful OAuth login.""" """Test successful OAuth login."""
response = await client.post( response = await client.post(
"/api/v1/auth/login/oauth", "/api/v1/auth/login/oauth",
data={ data={"username": "testuser@example.com", "password": "TestPassword123!"},
"username": "testuser@example.com",
"password": "TestPassword123!"
}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -315,10 +292,7 @@ class TestOAuthLogin:
"""Test OAuth login with invalid credentials.""" """Test OAuth login with invalid credentials."""
response = await client.post( response = await client.post(
"/api/v1/auth/login/oauth", "/api/v1/auth/login/oauth",
data={ data={"username": "testuser@example.com", "password": "WrongPassword"},
"username": "testuser@example.com",
"password": "WrongPassword"
}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -1,15 +1,16 @@
# tests/api/dependencies/test_auth_dependencies.py # tests/api/dependencies/test_auth_dependencies.py
import pytest
import pytest_asyncio
import uuid import uuid
from unittest.mock import patch from unittest.mock import patch
import pytest
import pytest_asyncio
from fastapi import HTTPException from fastapi import HTTPException
from app.api.dependencies.auth import ( from app.api.dependencies.auth import (
get_current_user,
get_current_active_user, get_current_active_user,
get_current_superuser, get_current_superuser,
get_optional_current_user get_current_user,
get_optional_current_user,
) )
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
from app.models.user import User from app.models.user import User
@@ -24,7 +25,7 @@ def mock_token():
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def async_mock_user(async_test_db): async def async_mock_user(async_test_db):
"""Async fixture to create and return a mock User instance.""" """Async fixture to create and return a mock User instance."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
mock_user = User( mock_user = User(
id=uuid.uuid4(), id=uuid.uuid4(),
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
"""Tests for get_current_user dependency""" """Tests for get_current_user dependency"""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token): async def test_get_current_user_success(
self, async_test_db, async_mock_user, mock_token
):
"""Test successfully getting the current user""" """Test successfully getting the current user"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return user_id that matches our mock_user # Mock get_token_data to return user_id that matches our mock_user
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency # Call the dependency
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_nonexistent(self, async_test_db, mock_token): async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
"""Test when the token contains a user ID that doesn't exist""" """Test when the token contains a user ID that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return a non-existent user ID # Mock get_token_data to return a non-existent user ID
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111") nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = nonexistent_id mock_get_data.return_value.user_id = nonexistent_id
# Should raise HTTPException with 404 status # Should raise HTTPException with 404 status
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
assert "User not found" in exc_info.value.detail assert "User not found" in exc_info.value.detail
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token): async def test_get_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test when the user is inactive""" """Test when the user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive # Get the user in this session and make it inactive
from sqlalchemy import select from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none() user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False user_in_session.is_active = False
await session.commit() await session.commit()
# Mock get_token_data # Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id mock_get_data.return_value.user_id = async_mock_user.id
# Should raise HTTPException with 403 status # Should raise HTTPException with 403 status
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_expired_token(self, async_test_db, mock_token): async def test_get_current_user_expired_token(self, async_test_db, mock_token):
"""Test with an expired token""" """Test with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError # Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired") mock_get_data.side_effect = TokenExpiredError("Token expired")
# Should raise HTTPException with 401 status # Should raise HTTPException with 401 status
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_invalid_token(self, async_test_db, mock_token): async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
"""Test with an invalid token""" """Test with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError # Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token") mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Should raise HTTPException with 401 status # Should raise HTTPException with 401 status
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
"""Tests for get_optional_current_user dependency""" """Tests for get_optional_current_user dependency"""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token): async def test_get_optional_current_user_with_token(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user with a valid token""" """Test getting optional user with a valid token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data # Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency # Call the dependency
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_no_token(self, async_test_db): async def test_get_optional_current_user_no_token(self, async_test_db):
"""Test getting optional user with no token""" """Test getting optional user with no token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Call the dependency with no token # Call the dependency with no token
user = await get_optional_current_user(db=session, token=None) user = await get_optional_current_user(db=session, token=None)
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
assert user is None assert user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token): async def test_get_optional_current_user_invalid_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an invalid token""" """Test getting optional user with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError # Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token") mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Call the dependency # Call the dependency
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
assert user is None assert user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token): async def test_get_optional_current_user_expired_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an expired token""" """Test getting optional user with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError # Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired") mock_get_data.side_effect = TokenExpiredError("Token expired")
# Call the dependency # Call the dependency
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
assert user is None assert user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token): async def test_get_optional_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user when user is inactive""" """Test getting optional user when user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive # Get the user in this session and make it inactive
from sqlalchemy import select from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none() user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False user_in_session.is_active = False
await session.commit() await session.commit()
# Mock get_token_data # Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data: with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency # Call the dependency

View File

@@ -2,21 +2,21 @@
""" """
Tests for authentication endpoints. Tests for authentication endpoints.
""" """
from unittest.mock import patch
import pytest import pytest
import pytest_asyncio
from unittest.mock import patch, MagicMock
from fastapi import status from fastapi import status
from sqlalchemy import select from sqlalchemy import select
from app.models.user import User from app.models.user import User
from app.schemas.users import UserCreate
# Disable rate limiting for tests # Disable rate limiting for tests
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def disable_rate_limit(): def disable_rate_limit():
"""Disable rate limiting for all tests in this module.""" """Disable rate limiting for all tests in this module."""
with patch('app.api.routes.auth.limiter.enabled', False): with patch("app.api.routes.auth.limiter.enabled", False):
yield yield
@@ -32,8 +32,8 @@ class TestRegisterEndpoint:
"email": "newuser@example.com", "email": "newuser@example.com",
"password": "SecurePassword123!", "password": "SecurePassword123!",
"first_name": "New", "first_name": "New",
"last_name": "User" "last_name": "User",
} },
) )
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
@@ -54,8 +54,8 @@ class TestRegisterEndpoint:
"email": async_test_user.email, "email": async_test_user.email,
"password": "SecurePassword123!", "password": "SecurePassword123!",
"first_name": "Duplicate", "first_name": "Duplicate",
"last_name": "User" "last_name": "User",
} },
) )
# Security: Returns 400 with generic message to prevent email enumeration # Security: Returns 400 with generic message to prevent email enumeration
@@ -73,8 +73,8 @@ class TestRegisterEndpoint:
"email": "weakpass@example.com", "email": "weakpass@example.com",
"password": "weak", "password": "weak",
"first_name": "Weak", "first_name": "Weak",
"last_name": "Pass" "last_name": "Pass",
} },
) )
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -82,7 +82,7 @@ class TestRegisterEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_unexpected_error(self, client): async def test_register_unexpected_error(self, client):
"""Test registration with unexpected error.""" """Test registration with unexpected error."""
with patch('app.services.auth_service.AuthService.create_user') as mock_create: with patch("app.services.auth_service.AuthService.create_user") as mock_create:
mock_create.side_effect = Exception("Unexpected error") mock_create.side_effect = Exception("Unexpected error")
response = await client.post( response = await client.post(
@@ -91,8 +91,8 @@ class TestRegisterEndpoint:
"email": "error@example.com", "email": "error@example.com",
"password": "SecurePassword123!", "password": "SecurePassword123!",
"first_name": "Error", "first_name": "Error",
"last_name": "User" "last_name": "User",
} },
) )
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -106,10 +106,7 @@ class TestLoginEndpoint:
"""Test successful login.""" """Test successful login."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": async_test_user.email, "password": "TestPassword123!"},
"email": async_test_user.email,
"password": "TestPassword123!"
}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -123,10 +120,7 @@ class TestLoginEndpoint:
"""Test login with wrong password.""" """Test login with wrong password."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": async_test_user.email, "password": "WrongPassword123"},
"email": async_test_user.email,
"password": "WrongPassword123"
}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -136,10 +130,7 @@ class TestLoginEndpoint:
"""Test login with non-existent email.""" """Test login with non-existent email."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "nonexistent@example.com", "password": "Password123!"},
"email": "nonexistent@example.com",
"password": "Password123!"
}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -147,20 +138,19 @@ class TestLoginEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_inactive_user(self, client, async_test_user, async_test_db): async def test_login_inactive_user(self, client, async_test_user, async_test_db):
"""Test login with inactive user.""" """Test login with inactive user."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive # Get the user in this session and make it inactive
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none() user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False user_in_session.is_active = False
await session.commit() await session.commit()
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": async_test_user.email, "password": "TestPassword123!"},
"email": async_test_user.email,
"password": "TestPassword123!"
}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -168,15 +158,14 @@ class TestLoginEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_unexpected_error(self, client, async_test_user): async def test_login_unexpected_error(self, client, async_test_user):
"""Test login with unexpected error.""" """Test login with unexpected error."""
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth: with patch(
"app.services.auth_service.AuthService.authenticate_user"
) as mock_auth:
mock_auth.side_effect = Exception("Database error") mock_auth.side_effect = Exception("Database error")
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": async_test_user.email, "password": "TestPassword123!"},
"email": async_test_user.email,
"password": "TestPassword123!"
}
) )
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -190,10 +179,7 @@ class TestOAuthLoginEndpoint:
"""Test successful OAuth login.""" """Test successful OAuth login."""
response = await client.post( response = await client.post(
"/api/v1/auth/login/oauth", "/api/v1/auth/login/oauth",
data={ data={"username": async_test_user.email, "password": "TestPassword123!"},
"username": async_test_user.email,
"password": "TestPassword123!"
}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -206,31 +192,29 @@ class TestOAuthLoginEndpoint:
"""Test OAuth login with wrong credentials.""" """Test OAuth login with wrong credentials."""
response = await client.post( response = await client.post(
"/api/v1/auth/login/oauth", "/api/v1/auth/login/oauth",
data={ data={"username": async_test_user.email, "password": "WrongPassword"},
"username": async_test_user.email,
"password": "WrongPassword"
}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_login_inactive_user(self, client, async_test_user, async_test_db): async def test_oauth_login_inactive_user(
self, client, async_test_user, async_test_db
):
"""Test OAuth login with inactive user.""" """Test OAuth login with inactive user."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive # Get the user in this session and make it inactive
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none() user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False user_in_session.is_active = False
await session.commit() await session.commit()
response = await client.post( response = await client.post(
"/api/v1/auth/login/oauth", "/api/v1/auth/login/oauth",
data={ data={"username": async_test_user.email, "password": "TestPassword123!"},
"username": async_test_user.email,
"password": "TestPassword123!"
}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -238,15 +222,17 @@ class TestOAuthLoginEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_login_unexpected_error(self, client, async_test_user): async def test_oauth_login_unexpected_error(self, client, async_test_user):
"""Test OAuth login with unexpected error.""" """Test OAuth login with unexpected error."""
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth: with patch(
"app.services.auth_service.AuthService.authenticate_user"
) as mock_auth:
mock_auth.side_effect = Exception("Unexpected error") mock_auth.side_effect = Exception("Unexpected error")
response = await client.post( response = await client.post(
"/api/v1/auth/login/oauth", "/api/v1/auth/login/oauth",
data={ data={
"username": async_test_user.email, "username": async_test_user.email,
"password": "TestPassword123!" "password": "TestPassword123!",
} },
) )
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -261,17 +247,13 @@ class TestRefreshTokenEndpoint:
# First, login to get a refresh token # First, login to get a refresh token
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": async_test_user.email, "password": "TestPassword123!"},
"email": async_test_user.email,
"password": "TestPassword123!"
}
) )
refresh_token = login_response.json()["refresh_token"] refresh_token = login_response.json()["refresh_token"]
# Now refresh the token # Now refresh the token
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": refresh_token}
json={"refresh_token": refresh_token}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -284,12 +266,13 @@ class TestRefreshTokenEndpoint:
"""Test refresh with expired token.""" """Test refresh with expired token."""
from app.core.auth import TokenExpiredError from app.core.auth import TokenExpiredError
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh: with patch(
"app.services.auth_service.AuthService.refresh_tokens"
) as mock_refresh:
mock_refresh.side_effect = TokenExpiredError("Token expired") mock_refresh.side_effect = TokenExpiredError("Token expired")
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": "some_token"}
json={"refresh_token": "some_token"}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -298,8 +281,7 @@ class TestRefreshTokenEndpoint:
async def test_refresh_token_invalid(self, client): async def test_refresh_token_invalid(self, client):
"""Test refresh with invalid token.""" """Test refresh with invalid token."""
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": "invalid_token"}
json={"refresh_token": "invalid_token"}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -310,19 +292,17 @@ class TestRefreshTokenEndpoint:
# Get a valid refresh token first # Get a valid refresh token first
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": async_test_user.email, "password": "TestPassword123!"},
"email": async_test_user.email,
"password": "TestPassword123!"
}
) )
refresh_token = login_response.json()["refresh_token"] refresh_token = login_response.json()["refresh_token"]
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh: with patch(
"app.services.auth_service.AuthService.refresh_tokens"
) as mock_refresh:
mock_refresh.side_effect = Exception("Unexpected error") mock_refresh.side_effect = Exception("Unexpected error")
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": refresh_token}
json={"refresh_token": refresh_token}
) )
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

View File

@@ -2,8 +2,10 @@
""" """
Tests for auth route exception handlers and error paths. Tests for auth route exception handlers and error paths.
""" """
from unittest.mock import patch
import pytest import pytest
from unittest.mock import patch, AsyncMock
from fastapi import status from fastapi import status
@@ -11,16 +13,18 @@ class TestLoginSessionCreationFailure:
"""Test login when session creation fails.""" """Test login when session creation fails."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_succeeds_despite_session_creation_failure(self, client, async_test_user): async def test_login_succeeds_despite_session_creation_failure(
self, client, async_test_user
):
"""Test that login succeeds even if session creation fails.""" """Test that login succeeds even if session creation fails."""
# Mock session creation to fail # Mock session creation to fail
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session creation failed")): with patch(
"app.api.routes.auth.session_crud.create_session",
side_effect=Exception("Session creation failed"),
):
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
# Login should still succeed, just without session record # Login should still succeed, just without session record
@@ -34,15 +38,20 @@ class TestOAuthLoginSessionCreationFailure:
"""Test OAuth login when session creation fails.""" """Test OAuth login when session creation fails."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_login_succeeds_despite_session_failure(self, client, async_test_user): async def test_oauth_login_succeeds_despite_session_failure(
self, client, async_test_user
):
"""Test OAuth login succeeds even if session creation fails.""" """Test OAuth login succeeds even if session creation fails."""
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session failed")): with patch(
"app.api.routes.auth.session_crud.create_session",
side_effect=Exception("Session failed"),
):
response = await client.post( response = await client.post(
"/api/v1/auth/login/oauth", "/api/v1/auth/login/oauth",
data={ data={
"username": "testuser@example.com", "username": "testuser@example.com",
"password": "TestPassword123!" "password": "TestPassword123!",
} },
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -54,23 +63,24 @@ class TestRefreshTokenSessionUpdateFailure:
"""Test refresh token when session update fails.""" """Test refresh token when session update fails."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_token_succeeds_despite_session_update_failure(self, client, async_test_user): async def test_refresh_token_succeeds_despite_session_update_failure(
self, client, async_test_user
):
"""Test that token refresh succeeds even if session update fails.""" """Test that token refresh succeeds even if session update fails."""
# First login to get tokens # First login to get tokens
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
tokens = response.json() tokens = response.json()
# Mock session update to fail # Mock session update to fail
with patch('app.api.routes.auth.session_crud.update_refresh_token', side_effect=Exception("Update failed")): with patch(
"app.api.routes.auth.session_crud.update_refresh_token",
side_effect=Exception("Update failed"),
):
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": tokens["refresh_token"]}
json={"refresh_token": tokens["refresh_token"]}
) )
# Should still succeed - tokens are issued before update # Should still succeed - tokens are issued before update
@@ -83,15 +93,14 @@ class TestLogoutWithExpiredToken:
"""Test logout with expired/invalid token.""" """Test logout with expired/invalid token."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_logout_with_invalid_token_still_succeeds(self, client, async_test_user): async def test_logout_with_invalid_token_still_succeeds(
self, client, async_test_user
):
"""Test logout succeeds even with invalid refresh token.""" """Test logout succeeds even with invalid refresh token."""
# Login first # Login first
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
access_token = response.json()["access_token"] access_token = response.json()["access_token"]
@@ -99,7 +108,7 @@ class TestLogoutWithExpiredToken:
response = await client.post( response = await client.post(
"/api/v1/auth/logout", "/api/v1/auth/logout",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
json={"refresh_token": "invalid.token.here"} json={"refresh_token": "invalid.token.here"},
) )
# Should succeed (idempotent) # Should succeed (idempotent)
@@ -116,19 +125,16 @@ class TestLogoutWithNonExistentSession:
"""Test logout succeeds even if session not found.""" """Test logout succeeds even if session not found."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
tokens = response.json() tokens = response.json()
# Mock session lookup to return None # Mock session lookup to return None
with patch('app.api.routes.auth.session_crud.get_by_jti', return_value=None): with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None):
response = await client.post( response = await client.post(
"/api/v1/auth/logout", "/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"}, headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]} json={"refresh_token": tokens["refresh_token"]},
) )
# Should succeed (idempotent) # Should succeed (idempotent)
@@ -139,23 +145,25 @@ class TestLogoutUnexpectedError:
"""Test logout with unexpected errors.""" """Test logout with unexpected errors."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_logout_with_unexpected_error_returns_success(self, client, async_test_user): async def test_logout_with_unexpected_error_returns_success(
self, client, async_test_user
):
"""Test logout returns success even on unexpected errors.""" """Test logout returns success even on unexpected errors."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
tokens = response.json() tokens = response.json()
# Mock to raise unexpected error # Mock to raise unexpected error
with patch('app.api.routes.auth.session_crud.get_by_jti', side_effect=Exception("Unexpected error")): with patch(
"app.api.routes.auth.session_crud.get_by_jti",
side_effect=Exception("Unexpected error"),
):
response = await client.post( response = await client.post(
"/api/v1/auth/logout", "/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"}, headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]} json={"refresh_token": tokens["refresh_token"]},
) )
# Should still return success (don't expose errors) # Should still return success (don't expose errors)
@@ -172,18 +180,18 @@ class TestLogoutAllUnexpectedError:
"""Test logout-all handles database errors.""" """Test logout-all handles database errors."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
access_token = response.json()["access_token"] access_token = response.json()["access_token"]
# Mock to raise database error # Mock to raise database error
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("DB error")): with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
side_effect=Exception("DB error"),
):
response = await client.post( response = await client.post(
"/api/v1/auth/logout-all", "/api/v1/auth/logout-all",
headers={"Authorization": f"Bearer {access_token}"} headers={"Authorization": f"Bearer {access_token}"},
) )
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -193,7 +201,9 @@ class TestPasswordResetConfirmSessionInvalidation:
"""Test password reset invalidates sessions.""" """Test password reset invalidates sessions."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_continues_despite_session_invalidation_failure(self, client, async_test_user): async def test_password_reset_continues_despite_session_invalidation_failure(
self, client, async_test_user
):
"""Test password reset succeeds even if session invalidation fails.""" """Test password reset succeeds even if session invalidation fails."""
# Create a valid password reset token # Create a valid password reset token
from app.utils.security import create_password_reset_token from app.utils.security import create_password_reset_token
@@ -201,13 +211,13 @@ class TestPasswordResetConfirmSessionInvalidation:
token = create_password_reset_token(async_test_user.email) token = create_password_reset_token(async_test_user.email)
# Mock session invalidation to fail # Mock session invalidation to fail
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("Invalidation failed")): with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
side_effect=Exception("Invalidation failed"),
):
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": token, "new_password": "NewPassword123!"},
"token": token,
"new_password": "NewPassword123!"
}
) )
# Should still succeed - password was reset # Should still succeed - password was reset

View File

@@ -2,22 +2,22 @@
""" """
Tests for password reset endpoints. Tests for password reset endpoints.
""" """
from unittest.mock import patch
import pytest import pytest
import pytest_asyncio
from unittest.mock import patch, AsyncMock, MagicMock
from fastapi import status from fastapi import status
from sqlalchemy import select from sqlalchemy import select
from app.schemas.users import PasswordResetRequest, PasswordResetConfirm
from app.utils.security import create_password_reset_token
from app.models.user import User from app.models.user import User
from app.utils.security import create_password_reset_token
# Disable rate limiting for tests # Disable rate limiting for tests
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def disable_rate_limit(): def disable_rate_limit():
"""Disable rate limiting for all tests in this module.""" """Disable rate limiting for all tests in this module."""
with patch('app.api.routes.auth.limiter.enabled', False): with patch("app.api.routes.auth.limiter.enabled", False):
yield yield
@@ -27,12 +27,14 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_request_valid_email(self, client, async_test_user): async def test_password_reset_request_valid_email(self, client, async_test_user):
"""Test password reset request with valid email.""" """Test password reset request with valid email."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.return_value = True mock_send.return_value = True
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/request", "/api/v1/auth/password-reset/request",
json={"email": async_test_user.email} json={"email": async_test_user.email},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -50,10 +52,12 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_request_nonexistent_email(self, client): async def test_password_reset_request_nonexistent_email(self, client):
"""Test password reset request with non-existent email.""" """Test password reset request with non-existent email."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/request", "/api/v1/auth/password-reset/request",
json={"email": "nonexistent@example.com"} json={"email": "nonexistent@example.com"},
) )
# Should still return success to prevent email enumeration # Should still return success to prevent email enumeration
@@ -65,20 +69,26 @@ class TestPasswordResetRequest:
mock_send.assert_not_called() mock_send.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_request_inactive_user(self, client, async_test_db, async_test_user): async def test_password_reset_request_inactive_user(
self, client, async_test_db, async_test_user
):
"""Test password reset request with inactive user.""" """Test password reset request with inactive user."""
# Deactivate user # Deactivate user
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none() user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False user_in_session.is_active = False
await session.commit() await session.commit()
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/request", "/api/v1/auth/password-reset/request",
json={"email": async_test_user.email} json={"email": async_test_user.email},
) )
# Should still return success to prevent email enumeration # Should still return success to prevent email enumeration
@@ -93,8 +103,7 @@ class TestPasswordResetRequest:
async def test_password_reset_request_invalid_email_format(self, client): async def test_password_reset_request_invalid_email_format(self, client):
"""Test password reset request with invalid email format.""" """Test password reset request with invalid email format."""
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/request", "/api/v1/auth/password-reset/request", json={"email": "not-an-email"}
json={"email": "not-an-email"}
) )
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -102,22 +111,23 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_request_missing_email(self, client): async def test_password_reset_request_missing_email(self, client):
"""Test password reset request without email.""" """Test password reset request without email."""
response = await client.post( response = await client.post("/api/v1/auth/password-reset/request", json={})
"/api/v1/auth/password-reset/request",
json={}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_request_email_service_error(self, client, async_test_user): async def test_password_reset_request_email_service_error(
self, client, async_test_user
):
"""Test password reset when email service fails.""" """Test password reset when email service fails."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.side_effect = Exception("SMTP Error") mock_send.side_effect = Exception("SMTP Error")
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/request", "/api/v1/auth/password-reset/request",
json={"email": async_test_user.email} json={"email": async_test_user.email},
) )
# Should still return success even if email fails # Should still return success even if email fails
@@ -128,14 +138,16 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_request_rate_limiting(self, client, async_test_user): async def test_password_reset_request_rate_limiting(self, client, async_test_user):
"""Test that password reset requests are rate limited.""" """Test that password reset requests are rate limited."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.return_value = True mock_send.return_value = True
# Make multiple requests quickly (3/minute limit) # Make multiple requests quickly (3/minute limit)
for _ in range(3): for _ in range(3):
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/request", "/api/v1/auth/password-reset/request",
json={"email": async_test_user.email} json={"email": async_test_user.email},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -144,7 +156,9 @@ class TestPasswordResetConfirm:
"""Tests for POST /auth/password-reset/confirm endpoint.""" """Tests for POST /auth/password-reset/confirm endpoint."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_confirm_valid_token(self, client, async_test_user, async_test_db): async def test_password_reset_confirm_valid_token(
self, client, async_test_user, async_test_db
):
"""Test password reset confirmation with valid token.""" """Test password reset confirmation with valid token."""
# Generate valid token # Generate valid token
token = create_password_reset_token(async_test_user.email) token = create_password_reset_token(async_test_user.email)
@@ -152,10 +166,7 @@ class TestPasswordResetConfirm:
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": token, "new_password": new_password},
"token": token,
"new_password": new_password
}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -164,11 +175,14 @@ class TestPasswordResetConfirm:
assert "successfully" in data["message"].lower() assert "successfully" in data["message"].lower()
# Verify user can login with new password # Verify user can login with new password
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
updated_user = result.scalar_one_or_none() updated_user = result.scalar_one_or_none()
from app.core.auth import verify_password from app.core.auth import verify_password
assert verify_password(new_password, updated_user.password_hash) is True assert verify_password(new_password, updated_user.password_hash) is True
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -184,10 +198,7 @@ class TestPasswordResetConfirm:
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": token, "new_password": "NewSecure123!"},
"token": token,
"new_password": "NewSecure123!"
}
) )
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -202,10 +213,7 @@ class TestPasswordResetConfirm:
"""Test password reset confirmation with invalid token.""" """Test password reset confirmation with invalid token."""
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": "invalid_token_xyz", "new_password": "NewSecure123!"},
"token": "invalid_token_xyz",
"new_password": "NewSecure123!"
}
) )
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -222,19 +230,18 @@ class TestPasswordResetConfirm:
# Create valid token and tamper with it # Create valid token and tamper with it
token = create_password_reset_token(async_test_user.email) token = create_password_reset_token(async_test_user.email)
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded) token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com" token_data["payload"]["email"] = "hacker@example.com"
# Re-encode tampered token # Re-encode tampered token
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8') tampered = base64.urlsafe_b64encode(
json.dumps(token_data).encode("utf-8")
).decode("utf-8")
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": tampered, "new_password": "NewSecure123!"},
"token": tampered,
"new_password": "NewSecure123!"
}
) )
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -247,10 +254,7 @@ class TestPasswordResetConfirm:
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": token, "new_password": "NewSecure123!"},
"token": token,
"new_password": "NewSecure123!"
}
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -260,12 +264,16 @@ class TestPasswordResetConfirm:
assert "not found" in error_msg assert "not found" in error_msg
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_confirm_inactive_user(self, client, async_test_user, async_test_db): async def test_password_reset_confirm_inactive_user(
self, client, async_test_user, async_test_db
):
"""Test password reset confirmation for inactive user.""" """Test password reset confirmation for inactive user."""
# Deactivate user # Deactivate user
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none() user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False user_in_session.is_active = False
await session.commit() await session.commit()
@@ -274,10 +282,7 @@ class TestPasswordResetConfirm:
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": token, "new_password": "NewSecure123!"},
"token": token,
"new_password": "NewSecure123!"
}
) )
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -301,10 +306,7 @@ class TestPasswordResetConfirm:
for weak_password in weak_passwords: for weak_password in weak_passwords:
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": token, "new_password": weak_password},
"token": token,
"new_password": weak_password
}
) )
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -315,15 +317,14 @@ class TestPasswordResetConfirm:
# Missing token # Missing token
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={"new_password": "NewSecure123!"} json={"new_password": "NewSecure123!"},
) )
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
# Missing password # Missing password
token = create_password_reset_token("test@example.com") token = create_password_reset_token("test@example.com")
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm", json={"token": token}
json={"token": token}
) )
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -333,15 +334,12 @@ class TestPasswordResetConfirm:
token = create_password_reset_token(async_test_user.email) token = create_password_reset_token(async_test_user.email)
# Mock the database commit to raise an exception # Mock the database commit to raise an exception
with patch('app.api.routes.auth.user_crud.get_by_email') as mock_get: with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get:
mock_get.side_effect = Exception("Database error") mock_get.side_effect = Exception("Database error")
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": token, "new_password": "NewSecure123!"},
"token": token,
"new_password": "NewSecure123!"
}
) )
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -351,18 +349,22 @@ class TestPasswordResetConfirm:
assert "error" in error_msg or "resetting" in error_msg assert "error" in error_msg or "resetting" in error_msg
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_password_reset_full_flow(self, client, async_test_user, async_test_db): async def test_password_reset_full_flow(
self, client, async_test_user, async_test_db
):
"""Test complete password reset flow.""" """Test complete password reset flow."""
original_password = async_test_user.password_hash original_password = async_test_user.password_hash
new_password = "BrandNew123!" new_password = "BrandNew123!"
# Step 1: Request password reset # Step 1: Request password reset
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send: with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.return_value = True mock_send.return_value = True
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/request", "/api/v1/auth/password-reset/request",
json={"email": async_test_user.email} json={"email": async_test_user.email},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -374,29 +376,24 @@ class TestPasswordResetConfirm:
# Step 2: Confirm password reset # Step 2: Confirm password reset
response = await client.post( response = await client.post(
"/api/v1/auth/password-reset/confirm", "/api/v1/auth/password-reset/confirm",
json={ json={"token": reset_token, "new_password": new_password},
"token": reset_token,
"new_password": new_password
}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
# Step 3: Verify old password doesn't work # Step 3: Verify old password doesn't work
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
updated_user = result.scalar_one_or_none() updated_user = result.scalar_one_or_none()
from app.core.auth import verify_password
assert updated_user.password_hash != original_password assert updated_user.password_hash != original_password
# Step 4: Verify new password works # Step 4: Verify new password works
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": async_test_user.email, "password": new_password},
"email": async_test_user.email,
"password": new_password
}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK

View File

@@ -8,11 +8,10 @@ Critical security tests covering:
These tests prevent real-world attack scenarios. These tests prevent real-world attack scenarios.
""" """
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import create_refresh_token
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.models.user import User from app.models.user import User
@@ -30,10 +29,7 @@ class TestRevokedSessionSecurity:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_token_rejected_after_logout( async def test_refresh_token_rejected_after_logout(
self, self, client: AsyncClient, async_test_db, async_test_user: User
client: AsyncClient,
async_test_db,
async_test_user: User
): ):
""" """
Test that refresh tokens are rejected after session is deactivated. Test that refresh tokens are rejected after session is deactivated.
@@ -45,10 +41,10 @@ class TestRevokedSessionSecurity:
4. Attacker tries to use stolen refresh token 4. Attacker tries to use stolen refresh token
5. System MUST reject it (session revoked) 5. System MUST reject it (session revoked)
""" """
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Step 1: Create a session and refresh token for the user # Step 1: Create a session and refresh token for the user
async with SessionLocal() as session: async with SessionLocal():
# Login to get tokens # Login to get tokens
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
@@ -64,8 +60,7 @@ class TestRevokedSessionSecurity:
# Step 2: Verify refresh token works before logout # Step 2: Verify refresh token works before logout
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": refresh_token}
json={"refresh_token": refresh_token}
) )
assert response.status_code == 200, "Refresh should work before logout" assert response.status_code == 200, "Refresh should work before logout"
@@ -73,14 +68,13 @@ class TestRevokedSessionSecurity:
response = await client.post( response = await client.post(
"/api/v1/auth/logout", "/api/v1/auth/logout",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
json={"refresh_token": refresh_token} json={"refresh_token": refresh_token},
) )
assert response.status_code == 200, "Logout should succeed" assert response.status_code == 200, "Logout should succeed"
# Step 4: Attacker tries to use stolen refresh token # Step 4: Attacker tries to use stolen refresh token
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": refresh_token}
json={"refresh_token": refresh_token}
) )
# Step 5: System MUST reject (covers lines 261-262) # Step 5: System MUST reject (covers lines 261-262)
@@ -93,10 +87,7 @@ class TestRevokedSessionSecurity:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_token_rejected_for_deleted_session( async def test_refresh_token_rejected_for_deleted_session(
self, self, client: AsyncClient, async_test_db, async_test_user: User
client: AsyncClient,
async_test_db,
async_test_user: User
): ):
""" """
Test that tokens for deleted sessions are rejected. Test that tokens for deleted sessions are rejected.
@@ -104,7 +95,7 @@ class TestRevokedSessionSecurity:
Attack Scenario: Attack Scenario:
Admin deletes a session from database, but attacker has the token. Admin deletes a session from database, but attacker has the token.
""" """
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Step 1: Login to create a session # Step 1: Login to create a session
response = await client.post( response = await client.post(
@@ -120,6 +111,7 @@ class TestRevokedSessionSecurity:
# Step 2: Manually delete the session from database (simulating admin action) # Step 2: Manually delete the session from database (simulating admin action)
from app.core.auth import decode_token from app.core.auth import decode_token
token_data = decode_token(refresh_token, verify_type="refresh") token_data = decode_token(refresh_token, verify_type="refresh")
jti = token_data.jti jti = token_data.jti
@@ -132,15 +124,17 @@ class TestRevokedSessionSecurity:
# Step 3: Try to use the refresh token # Step 3: Try to use the refresh token
response = await client.post( response = await client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh", json={"refresh_token": refresh_token}
json={"refresh_token": refresh_token}
) )
# Should reject (session doesn't exist) # Should reject (session doesn't exist)
assert response.status_code == 401 assert response.status_code == 401
data = response.json() data = response.json()
if "errors" in data: if "errors" in data:
assert "revoked" in data["errors"][0]["message"].lower() or "session" in data["errors"][0]["message"].lower() assert (
"revoked" in data["errors"][0]["message"].lower()
or "session" in data["errors"][0]["message"].lower()
)
else: else:
assert "revoked" in data.get("detail", "").lower() assert "revoked" in data.get("detail", "").lower()
@@ -162,7 +156,7 @@ class TestSessionHijackingSecurity:
client: AsyncClient, client: AsyncClient,
async_test_db, async_test_db,
async_test_user: User, async_test_user: User,
async_test_superuser: User async_test_superuser: User,
): ):
""" """
Test that users cannot logout other users' sessions. Test that users cannot logout other users' sessions.
@@ -173,7 +167,7 @@ class TestSessionHijackingSecurity:
3. User A tries to logout User B's session 3. User A tries to logout User B's session
4. System MUST reject (cross-user attack) 4. System MUST reject (cross-user attack)
""" """
test_engine, SessionLocal = async_test_db _test_engine, _SessionLocal = async_test_db
# Step 1: User A logs in # Step 1: User A logs in
response = await client.post( response = await client.post(
@@ -202,8 +196,10 @@ class TestSessionHijackingSecurity:
# Step 3: User A tries to logout User B's session using User B's refresh token # Step 3: User A tries to logout User B's session using User B's refresh token
response = await client.post( response = await client.post(
"/api/v1/auth/logout", "/api/v1/auth/logout",
headers={"Authorization": f"Bearer {user_a_access}"}, # User A's access token headers={
json={"refresh_token": user_b_refresh} # But User B's refresh token "Authorization": f"Bearer {user_a_access}"
}, # User A's access token
json={"refresh_token": user_b_refresh}, # But User B's refresh token
) )
# Step 4: System MUST reject (covers lines 509-513) # Step 4: System MUST reject (covers lines 509-513)
@@ -217,9 +213,7 @@ class TestSessionHijackingSecurity:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_users_can_logout_their_own_sessions( async def test_users_can_logout_their_own_sessions(
self, self, client: AsyncClient, async_test_user: User
client: AsyncClient,
async_test_user: User
): ):
""" """
Sanity check: Users CAN logout their own sessions. Sanity check: Users CAN logout their own sessions.
@@ -241,6 +235,8 @@ class TestSessionHijackingSecurity:
response = await client.post( response = await client.post(
"/api/v1/auth/logout", "/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"}, headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]} json={"refresh_token": tokens["refresh_token"]},
)
assert response.status_code == 200, (
"Users should be able to logout their own sessions"
) )
assert response.status_code == 200, "Users should be able to logout their own sessions"

View File

@@ -5,16 +5,18 @@ Tests for organization routes (user endpoints).
These test the routes in app/api/routes/organizations.py which allow These test the routes in app/api/routes/organizations.py which allow
users to view and manage organizations they belong to. users to view and manage organizations they belong to.
""" """
from unittest.mock import patch
from uuid import uuid4
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from fastapi import status from fastapi import status
from uuid import uuid4
from unittest.mock import patch, AsyncMock
from app.core.auth import get_password_hash
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole from app.models.user_organization import OrganizationRole, UserOrganization
from app.core.auth import get_password_hash
@pytest_asyncio.fixture @pytest_asyncio.fixture
@@ -22,10 +24,7 @@ async def user_token(client, async_test_user):
"""Get access token for regular user.""" """Get access token for regular user."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
assert response.status_code == 200 assert response.status_code == 200
return response.json()["access_token"] return response.json()["access_token"]
@@ -34,7 +33,7 @@ async def user_token(client, async_test_user):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def second_user(async_test_db): async def second_user(async_test_db):
"""Create a second test user.""" """Create a second test user."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = User( user = User(
id=uuid4(), id=uuid4(),
@@ -56,12 +55,12 @@ async def second_user(async_test_db):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_org_with_user_member(async_test_db, async_test_user): async def test_org_with_user_member(async_test_db, async_test_user):
"""Create a test organization with async_test_user as a member.""" """Create a test organization with async_test_user as a member."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org = Organization( org = Organization(
name="Member Org", name="Member Org",
slug="member-org", slug="member-org",
description="Test organization where user is a member" description="Test organization where user is a member",
) )
session.add(org) session.add(org)
await session.commit() await session.commit()
@@ -72,7 +71,7 @@ async def test_org_with_user_member(async_test_db, async_test_user):
user_id=async_test_user.id, user_id=async_test_user.id,
organization_id=org.id, organization_id=org.id,
role=OrganizationRole.MEMBER, role=OrganizationRole.MEMBER,
is_active=True is_active=True,
) )
session.add(membership) session.add(membership)
await session.commit() await session.commit()
@@ -83,12 +82,12 @@ async def test_org_with_user_member(async_test_db, async_test_user):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_org_with_user_admin(async_test_db, async_test_user): async def test_org_with_user_admin(async_test_db, async_test_user):
"""Create a test organization with async_test_user as an admin.""" """Create a test organization with async_test_user as an admin."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org = Organization( org = Organization(
name="Admin Org", name="Admin Org",
slug="admin-org", slug="admin-org",
description="Test organization where user is an admin" description="Test organization where user is an admin",
) )
session.add(org) session.add(org)
await session.commit() await session.commit()
@@ -99,7 +98,7 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
user_id=async_test_user.id, user_id=async_test_user.id,
organization_id=org.id, organization_id=org.id,
role=OrganizationRole.ADMIN, role=OrganizationRole.ADMIN,
is_active=True is_active=True,
) )
session.add(membership) session.add(membership)
await session.commit() await session.commit()
@@ -110,12 +109,12 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_org_with_user_owner(async_test_db, async_test_user): async def test_org_with_user_owner(async_test_db, async_test_user):
"""Create a test organization with async_test_user as owner.""" """Create a test organization with async_test_user as owner."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org = Organization( org = Organization(
name="Owner Org", name="Owner Org",
slug="owner-org", slug="owner-org",
description="Test organization where user is owner" description="Test organization where user is owner",
) )
session.add(org) session.add(org)
await session.commit() await session.commit()
@@ -126,7 +125,7 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
user_id=async_test_user.id, user_id=async_test_user.id,
organization_id=org.id, organization_id=org.id,
role=OrganizationRole.OWNER, role=OrganizationRole.OWNER,
is_active=True is_active=True,
) )
session.add(membership) session.add(membership)
await session.commit() await session.commit()
@@ -136,21 +135,18 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
# ===== GET /api/v1/organizations/me ===== # ===== GET /api/v1/organizations/me =====
class TestGetMyOrganizations: class TestGetMyOrganizations:
"""Tests for GET /api/v1/organizations/me endpoint.""" """Tests for GET /api/v1/organizations/me endpoint."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_my_organizations_success( async def test_get_my_organizations_success(
self, self, client, user_token, test_org_with_user_member, test_org_with_user_admin
client,
user_token,
test_org_with_user_member,
test_org_with_user_admin
): ):
"""Test successfully getting user's organizations (covers lines 54-79).""" """Test successfully getting user's organizations (covers lines 54-79)."""
response = await client.get( response = await client.get(
"/api/v1/organizations/me", "/api/v1/organizations/me",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -167,21 +163,15 @@ class TestGetMyOrganizations:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_my_organizations_filter_active( async def test_get_my_organizations_filter_active(
self, self, client, async_test_db, async_test_user, user_token
client,
async_test_db,
async_test_user,
user_token
): ):
"""Test filtering organizations by active status.""" """Test filtering organizations by active status."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create active org # Create active org
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
active_org = Organization( active_org = Organization(
name="Active Org", name="Active Org", slug="active-org-filter", is_active=True
slug="active-org-filter",
is_active=True
) )
session.add(active_org) session.add(active_org)
await session.commit() await session.commit()
@@ -192,14 +182,14 @@ class TestGetMyOrganizations:
user_id=async_test_user.id, user_id=async_test_user.id,
organization_id=active_org.id, organization_id=active_org.id,
role=OrganizationRole.MEMBER, role=OrganizationRole.MEMBER,
is_active=True is_active=True,
) )
session.add(membership) session.add(membership)
await session.commit() await session.commit()
response = await client.get( response = await client.get(
"/api/v1/organizations/me?is_active=true", "/api/v1/organizations/me?is_active=true",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -209,7 +199,7 @@ class TestGetMyOrganizations:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_my_organizations_empty(self, client, async_test_db): async def test_get_my_organizations_empty(self, client, async_test_db):
"""Test getting organizations when user has none.""" """Test getting organizations when user has none."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create user with no org memberships # Create user with no org memberships
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -219,7 +209,7 @@ class TestGetMyOrganizations:
password_hash=get_password_hash("TestPassword123!"), password_hash=get_password_hash("TestPassword123!"),
first_name="No", first_name="No",
last_name="Org", last_name="Org",
is_active=True is_active=True,
) )
session.add(user) session.add(user)
await session.commit() await session.commit()
@@ -227,13 +217,12 @@ class TestGetMyOrganizations:
# Login to get token # Login to get token
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"email": "noorg@example.com", "password": "TestPassword123!"} json={"email": "noorg@example.com", "password": "TestPassword123!"},
) )
token = login_response.json()["access_token"] token = login_response.json()["access_token"]
response = await client.get( response = await client.get(
"/api/v1/organizations/me", "/api/v1/organizations/me", headers={"Authorization": f"Bearer {token}"}
headers={"Authorization": f"Bearer {token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -243,20 +232,18 @@ class TestGetMyOrganizations:
# ===== GET /api/v1/organizations/{organization_id} ===== # ===== GET /api/v1/organizations/{organization_id} =====
class TestGetOrganization: class TestGetOrganization:
"""Tests for GET /api/v1/organizations/{organization_id} endpoint.""" """Tests for GET /api/v1/organizations/{organization_id} endpoint."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_organization_success( async def test_get_organization_success(
self, self, client, user_token, test_org_with_user_member
client,
user_token,
test_org_with_user_member
): ):
"""Test successfully getting organization details (covers lines 103-122).""" """Test successfully getting organization details (covers lines 103-122)."""
response = await client.get( response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}", f"/api/v1/organizations/{test_org_with_user_member.id}",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -272,7 +259,7 @@ class TestGetOrganization:
fake_org_id = uuid4() fake_org_id = uuid4()
response = await client.get( response = await client.get(
f"/api/v1/organizations/{fake_org_id}", f"/api/v1/organizations/{fake_org_id}",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
# Permission dependency checks membership before endpoint logic # Permission dependency checks membership before endpoint logic
@@ -283,20 +270,14 @@ class TestGetOrganization:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_organization_not_member( async def test_get_organization_not_member(
self, self, client, async_test_db, async_test_user
client,
async_test_db,
async_test_user
): ):
"""Test getting organization where user is not a member fails.""" """Test getting organization where user is not a member fails."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create org without adding user # Create org without adding user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org = Organization( org = Organization(name="Not Member Org", slug="not-member-org")
name="Not Member Org",
slug="not-member-org"
)
session.add(org) session.add(org)
await session.commit() await session.commit()
await session.refresh(org) await session.refresh(org)
@@ -305,13 +286,13 @@ class TestGetOrganization:
# Login as user # Login as user
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"} json={"email": "testuser@example.com", "password": "TestPassword123!"},
) )
token = login_response.json()["access_token"] token = login_response.json()["access_token"]
response = await client.get( response = await client.get(
f"/api/v1/organizations/{org_id}", f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {token}"} headers={"Authorization": f"Bearer {token}"},
) )
# Should fail permission check # Should fail permission check
@@ -320,6 +301,7 @@ class TestGetOrganization:
# ===== GET /api/v1/organizations/{organization_id}/members ===== # ===== GET /api/v1/organizations/{organization_id}/members =====
class TestGetOrganizationMembers: class TestGetOrganizationMembers:
"""Tests for GET /api/v1/organizations/{organization_id}/members endpoint.""" """Tests for GET /api/v1/organizations/{organization_id}/members endpoint."""
@@ -331,10 +313,10 @@ class TestGetOrganizationMembers:
async_test_user, async_test_user,
second_user, second_user,
user_token, user_token,
test_org_with_user_member test_org_with_user_member,
): ):
"""Test successfully getting organization members (covers lines 150-168).""" """Test successfully getting organization members (covers lines 150-168)."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Add second user to org # Add second user to org
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -342,14 +324,14 @@ class TestGetOrganizationMembers:
user_id=second_user.id, user_id=second_user.id,
organization_id=test_org_with_user_member.id, organization_id=test_org_with_user_member.id,
role=OrganizationRole.MEMBER, role=OrganizationRole.MEMBER,
is_active=True is_active=True,
) )
session.add(membership) session.add(membership)
await session.commit() await session.commit()
response = await client.get( response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members", f"/api/v1/organizations/{test_org_with_user_member.id}/members",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -360,15 +342,12 @@ class TestGetOrganizationMembers:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_organization_members_with_pagination( async def test_get_organization_members_with_pagination(
self, self, client, user_token, test_org_with_user_member
client,
user_token,
test_org_with_user_member
): ):
"""Test pagination parameters.""" """Test pagination parameters."""
response = await client.get( response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members?page=1&limit=10", f"/api/v1/organizations/{test_org_with_user_member.id}/members?page=1&limit=10",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -385,10 +364,10 @@ class TestGetOrganizationMembers:
async_test_user, async_test_user,
second_user, second_user,
user_token, user_token,
test_org_with_user_member test_org_with_user_member,
): ):
"""Test filtering members by active status.""" """Test filtering members by active status."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Add second user as inactive member # Add second user as inactive member
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -396,7 +375,7 @@ class TestGetOrganizationMembers:
user_id=second_user.id, user_id=second_user.id,
organization_id=test_org_with_user_member.id, organization_id=test_org_with_user_member.id,
role=OrganizationRole.MEMBER, role=OrganizationRole.MEMBER,
is_active=False is_active=False,
) )
session.add(membership) session.add(membership)
await session.commit() await session.commit()
@@ -404,7 +383,7 @@ class TestGetOrganizationMembers:
# Filter for active only # Filter for active only
response = await client.get( response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members?is_active=true", f"/api/v1/organizations/{test_org_with_user_member.id}/members?is_active=true",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -416,31 +395,26 @@ class TestGetOrganizationMembers:
# ===== PUT /api/v1/organizations/{organization_id} ===== # ===== PUT /api/v1/organizations/{organization_id} =====
class TestUpdateOrganization: class TestUpdateOrganization:
"""Tests for PUT /api/v1/organizations/{organization_id} endpoint.""" """Tests for PUT /api/v1/organizations/{organization_id} endpoint."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_organization_as_admin_success( async def test_update_organization_as_admin_success(
self, self, client, async_test_user, test_org_with_user_admin
client,
async_test_user,
test_org_with_user_admin
): ):
"""Test successfully updating organization as admin (covers lines 193-215).""" """Test successfully updating organization as admin (covers lines 193-215)."""
# Login as admin user # Login as admin user
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"} json={"email": "testuser@example.com", "password": "TestPassword123!"},
) )
admin_token = login_response.json()["access_token"] admin_token = login_response.json()["access_token"]
response = await client.put( response = await client.put(
f"/api/v1/organizations/{test_org_with_user_admin.id}", f"/api/v1/organizations/{test_org_with_user_admin.id}",
json={ json={"name": "Updated Admin Org", "description": "Updated description"},
"name": "Updated Admin Org", headers={"Authorization": f"Bearer {admin_token}"},
"description": "Updated description"
},
headers={"Authorization": f"Bearer {admin_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -450,23 +424,20 @@ class TestUpdateOrganization:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_organization_as_owner_success( async def test_update_organization_as_owner_success(
self, self, client, async_test_user, test_org_with_user_owner
client,
async_test_user,
test_org_with_user_owner
): ):
"""Test successfully updating organization as owner.""" """Test successfully updating organization as owner."""
# Login as owner user # Login as owner user
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"} json={"email": "testuser@example.com", "password": "TestPassword123!"},
) )
owner_token = login_response.json()["access_token"] owner_token = login_response.json()["access_token"]
response = await client.put( response = await client.put(
f"/api/v1/organizations/{test_org_with_user_owner.id}", f"/api/v1/organizations/{test_org_with_user_owner.id}",
json={"name": "Updated Owner Org"}, json={"name": "Updated Owner Org"},
headers={"Authorization": f"Bearer {owner_token}"} headers={"Authorization": f"Bearer {owner_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -475,16 +446,13 @@ class TestUpdateOrganization:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_organization_as_member_fails( async def test_update_organization_as_member_fails(
self, self, client, user_token, test_org_with_user_member
client,
user_token,
test_org_with_user_member
): ):
"""Test updating organization as regular member fails.""" """Test updating organization as regular member fails."""
response = await client.put( response = await client.put(
f"/api/v1/organizations/{test_org_with_user_member.id}", f"/api/v1/organizations/{test_org_with_user_member.id}",
json={"name": "Should Fail"}, json={"name": "Should Fail"},
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
# Should fail permission check (need admin or owner) # Should fail permission check (need admin or owner)
@@ -492,15 +460,13 @@ class TestUpdateOrganization:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_organization_not_found( async def test_update_organization_not_found(
self, self, client, test_org_with_user_admin
client,
test_org_with_user_admin
): ):
"""Test updating nonexistent organization returns 403 (permission check first).""" """Test updating nonexistent organization returns 403 (permission check first)."""
# Login as admin # Login as admin
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"} json={"email": "testuser@example.com", "password": "TestPassword123!"},
) )
admin_token = login_response.json()["access_token"] admin_token = login_response.json()["access_token"]
@@ -508,7 +474,7 @@ class TestUpdateOrganization:
response = await client.put( response = await client.put(
f"/api/v1/organizations/{fake_org_id}", f"/api/v1/organizations/{fake_org_id}",
json={"name": "Updated"}, json={"name": "Updated"},
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}"},
) )
# Permission dependency checks admin role before endpoint logic # Permission dependency checks admin role before endpoint logic
@@ -520,6 +486,7 @@ class TestUpdateOrganization:
# ===== Authentication Tests ===== # ===== Authentication Tests =====
class TestOrganizationAuthentication: class TestOrganizationAuthentication:
"""Test authentication requirements for organization endpoints.""" """Test authentication requirements for organization endpoints."""
@@ -548,14 +515,14 @@ class TestOrganizationAuthentication:
"""Test unauthenticated access to update fails.""" """Test unauthenticated access to update fails."""
fake_id = uuid4() fake_id = uuid4()
response = await client.put( response = await client.put(
f"/api/v1/organizations/{fake_id}", f"/api/v1/organizations/{fake_id}", json={"name": "Test"}
json={"name": "Test"}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
# ===== Exception Handler Tests (Database Error Scenarios) ===== # ===== Exception Handler Tests (Database Error Scenarios) =====
class TestOrganizationExceptionHandlers: class TestOrganizationExceptionHandlers:
""" """
Test exception handlers in organization endpoints. Test exception handlers in organization endpoints.
@@ -566,86 +533,74 @@ class TestOrganizationExceptionHandlers:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_my_organizations_database_error( async def test_get_my_organizations_database_error(
self, self, client, user_token, test_org_with_user_member
client,
user_token,
test_org_with_user_member
): ):
"""Test generic exception handler in get_my_organizations (covers lines 81-83).""" """Test generic exception handler in get_my_organizations (covers lines 81-83)."""
with patch( with patch(
"app.crud.organization.organization.get_user_organizations_with_details", "app.crud.organization.organization.get_user_organizations_with_details",
side_effect=Exception("Database connection lost") side_effect=Exception("Database connection lost"),
): ):
# The exception handler logs and re-raises, so we expect the exception # The exception handler logs and re-raises, so we expect the exception
# to propagate (which proves the handler executed) # to propagate (which proves the handler executed)
with pytest.raises(Exception, match="Database connection lost"): with pytest.raises(Exception, match="Database connection lost"):
await client.get( await client.get(
"/api/v1/organizations/me", "/api/v1/organizations/me",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_organization_database_error( async def test_get_organization_database_error(
self, self, client, user_token, test_org_with_user_member
client,
user_token,
test_org_with_user_member
): ):
"""Test generic exception handler in get_organization (covers lines 124-128).""" """Test generic exception handler in get_organization (covers lines 124-128)."""
with patch( with patch(
"app.crud.organization.organization.get", "app.crud.organization.organization.get",
side_effect=Exception("Database timeout") side_effect=Exception("Database timeout"),
): ):
with pytest.raises(Exception, match="Database timeout"): with pytest.raises(Exception, match="Database timeout"):
await client.get( await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}", f"/api/v1/organizations/{test_org_with_user_member.id}",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_organization_members_database_error( async def test_get_organization_members_database_error(
self, self, client, user_token, test_org_with_user_member
client,
user_token,
test_org_with_user_member
): ):
"""Test generic exception handler in get_organization_members (covers lines 170-172).""" """Test generic exception handler in get_organization_members (covers lines 170-172)."""
with patch( with patch(
"app.crud.organization.organization.get_organization_members", "app.crud.organization.organization.get_organization_members",
side_effect=Exception("Connection pool exhausted") side_effect=Exception("Connection pool exhausted"),
): ):
with pytest.raises(Exception, match="Connection pool exhausted"): with pytest.raises(Exception, match="Connection pool exhausted"):
await client.get( await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members", f"/api/v1/organizations/{test_org_with_user_member.id}/members",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_organization_database_error( async def test_update_organization_database_error(
self, self, client, async_test_user, test_org_with_user_admin
client,
async_test_user,
test_org_with_user_admin
): ):
"""Test generic exception handler in update_organization (covers lines 217-221).""" """Test generic exception handler in update_organization (covers lines 217-221)."""
# Login as admin user # Login as admin user
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"} json={"email": "testuser@example.com", "password": "TestPassword123!"},
) )
admin_token = login_response.json()["access_token"] admin_token = login_response.json()["access_token"]
with patch( with patch(
"app.crud.organization.organization.get", "app.crud.organization.organization.get",
return_value=test_org_with_user_admin return_value=test_org_with_user_admin,
): ):
with patch( with patch(
"app.crud.organization.organization.update", "app.crud.organization.organization.update",
side_effect=Exception("Write lock timeout") side_effect=Exception("Write lock timeout"),
): ):
with pytest.raises(Exception, match="Write lock timeout"): with pytest.raises(Exception, match="Write lock timeout"):
await client.put( await client.put(
f"/api/v1/organizations/{test_org_with_user_admin.id}", f"/api/v1/organizations/{test_org_with_user_admin.id}",
json={"name": "Should Fail"}, json={"name": "Should Fail"},
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}"},
) )

View File

@@ -5,15 +5,17 @@ Tests for permission dependencies - CRITICAL SECURITY PATHS.
These tests ensure superusers can bypass organization checks correctly, These tests ensure superusers can bypass organization checks correctly,
and that regular users are properly blocked. and that regular users are properly blocked.
""" """
from uuid import uuid4
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from fastapi import status from fastapi import status
from uuid import uuid4
from app.core.auth import get_password_hash
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole from app.models.user_organization import OrganizationRole, UserOrganization
from app.core.auth import get_password_hash
@pytest_asyncio.fixture @pytest_asyncio.fixture
@@ -21,10 +23,7 @@ async def superuser_token(client, async_test_superuser):
"""Get access token for superuser.""" """Get access token for superuser."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "superuser@example.com", "password": "SuperPassword123!"},
"email": "superuser@example.com",
"password": "SuperPassword123!"
}
) )
assert response.status_code == 200 assert response.status_code == 200
return response.json()["access_token"] return response.json()["access_token"]
@@ -35,10 +34,7 @@ async def regular_user_token(client, async_test_user):
"""Get access token for regular user.""" """Get access token for regular user."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
assert response.status_code == 200 assert response.status_code == 200
return response.json()["access_token"] return response.json()["access_token"]
@@ -47,12 +43,12 @@ async def regular_user_token(client, async_test_user):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_org_no_members(async_test_db): async def test_org_no_members(async_test_db):
"""Create a test organization with NO members.""" """Create a test organization with NO members."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org = Organization( org = Organization(
name="No Members Org", name="No Members Org",
slug="no-members-org", slug="no-members-org",
description="Test org with no members" description="Test org with no members",
) )
session.add(org) session.add(org)
await session.commit() await session.commit()
@@ -63,12 +59,12 @@ async def test_org_no_members(async_test_db):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_org_with_member(async_test_db, async_test_user): async def test_org_with_member(async_test_db, async_test_user):
"""Create a test organization with async_test_user as member (not admin).""" """Create a test organization with async_test_user as member (not admin)."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org = Organization( org = Organization(
name="Member Only Org", name="Member Only Org",
slug="member-only-org", slug="member-only-org",
description="Test org where user is just a member" description="Test org where user is just a member",
) )
session.add(org) session.add(org)
await session.commit() await session.commit()
@@ -79,7 +75,7 @@ async def test_org_with_member(async_test_db, async_test_user):
user_id=async_test_user.id, user_id=async_test_user.id,
organization_id=org.id, organization_id=org.id,
role=OrganizationRole.MEMBER, role=OrganizationRole.MEMBER,
is_active=True is_active=True,
) )
session.add(membership) session.add(membership)
await session.commit() await session.commit()
@@ -89,6 +85,7 @@ async def test_org_with_member(async_test_db, async_test_user):
# ===== CRITICAL SECURITY TESTS: Superuser Bypass ===== # ===== CRITICAL SECURITY TESTS: Superuser Bypass =====
class TestSuperuserBypass: class TestSuperuserBypass:
""" """
CRITICAL: Test that superusers can bypass organization checks. CRITICAL: Test that superusers can bypass organization checks.
@@ -99,10 +96,7 @@ class TestSuperuserBypass:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_superuser_can_access_org_not_member_of( async def test_superuser_can_access_org_not_member_of(
self, self, client, superuser_token, test_org_no_members
client,
superuser_token,
test_org_no_members
): ):
""" """
CRITICAL: Superuser should bypass membership check (covers line 175). CRITICAL: Superuser should bypass membership check (covers line 175).
@@ -111,7 +105,7 @@ class TestSuperuserBypass:
""" """
response = await client.get( response = await client.get(
f"/api/v1/organizations/{test_org_no_members.id}", f"/api/v1/organizations/{test_org_no_members.id}",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
# Superuser should succeed even though they're not a member # Superuser should succeed even though they're not a member
@@ -121,15 +115,12 @@ class TestSuperuserBypass:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_access_org_not_member_of( async def test_regular_user_cannot_access_org_not_member_of(
self, self, client, regular_user_token, test_org_no_members
client,
regular_user_token,
test_org_no_members
): ):
"""Regular user should be blocked from org they're not a member of.""" """Regular user should be blocked from org they're not a member of."""
response = await client.get( response = await client.get(
f"/api/v1/organizations/{test_org_no_members.id}", f"/api/v1/organizations/{test_org_no_members.id}",
headers={"Authorization": f"Bearer {regular_user_token}"} headers={"Authorization": f"Bearer {regular_user_token}"},
) )
# Regular user should fail permission check # Regular user should fail permission check
@@ -137,10 +128,7 @@ class TestSuperuserBypass:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_superuser_can_update_org_not_admin_of( async def test_superuser_can_update_org_not_admin_of(
self, self, client, superuser_token, test_org_no_members
client,
superuser_token,
test_org_no_members
): ):
""" """
CRITICAL: Superuser should bypass admin check (covers line 99). CRITICAL: Superuser should bypass admin check (covers line 99).
@@ -150,7 +138,7 @@ class TestSuperuserBypass:
response = await client.put( response = await client.put(
f"/api/v1/organizations/{test_org_no_members.id}", f"/api/v1/organizations/{test_org_no_members.id}",
json={"name": "Updated by Superuser"}, json={"name": "Updated by Superuser"},
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
# Superuser should succeed in updating org # Superuser should succeed in updating org
@@ -160,16 +148,13 @@ class TestSuperuserBypass:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_member_cannot_update_org( async def test_regular_member_cannot_update_org(
self, self, client, regular_user_token, test_org_with_member
client,
regular_user_token,
test_org_with_member
): ):
"""Regular member (not admin) should NOT be able to update org.""" """Regular member (not admin) should NOT be able to update org."""
response = await client.put( response = await client.put(
f"/api/v1/organizations/{test_org_with_member.id}", f"/api/v1/organizations/{test_org_with_member.id}",
json={"name": "Should Fail"}, json={"name": "Should Fail"},
headers={"Authorization": f"Bearer {regular_user_token}"} headers={"Authorization": f"Bearer {regular_user_token}"},
) )
# Member should fail - need admin or owner role # Member should fail - need admin or owner role
@@ -177,15 +162,12 @@ class TestSuperuserBypass:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_superuser_can_list_org_members_not_member_of( async def test_superuser_can_list_org_members_not_member_of(
self, self, client, superuser_token, test_org_no_members
client,
superuser_token,
test_org_no_members
): ):
"""CRITICAL: Superuser should bypass membership check to list members.""" """CRITICAL: Superuser should bypass membership check to list members."""
response = await client.get( response = await client.get(
f"/api/v1/organizations/{test_org_no_members.id}/members", f"/api/v1/organizations/{test_org_no_members.id}/members",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
# Superuser should succeed # Superuser should succeed
@@ -197,13 +179,14 @@ class TestSuperuserBypass:
# ===== Edge Cases and Security Tests ===== # ===== Edge Cases and Security Tests =====
class TestPermissionEdgeCases: class TestPermissionEdgeCases:
"""Test edge cases in permission system.""" """Test edge cases in permission system."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_inactive_user_blocked(self, client, async_test_db): async def test_inactive_user_blocked(self, client, async_test_db):
"""Test that inactive users are blocked.""" """Test that inactive users are blocked."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create inactive user # Create inactive user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -213,7 +196,7 @@ class TestPermissionEdgeCases:
password_hash=get_password_hash("TestPassword123!"), password_hash=get_password_hash("TestPassword123!"),
first_name="Inactive", first_name="Inactive",
last_name="User", last_name="User",
is_active=False # INACTIVE is_active=False, # INACTIVE
) )
session.add(user) session.add(user)
await session.commit() await session.commit()
@@ -222,7 +205,7 @@ class TestPermissionEdgeCases:
# But accessing protected endpoints should fail # But accessing protected endpoints should fail
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"email": "inactive@example.com", "password": "TestPassword123!"} json={"email": "inactive@example.com", "password": "TestPassword123!"},
) )
# Login might fail for inactive users depending on auth implementation # Login might fail for inactive users depending on auth implementation
@@ -231,18 +214,18 @@ class TestPermissionEdgeCases:
# Try to access protected endpoint # Try to access protected endpoint
response = await client.get( response = await client.get(
"/api/v1/users/me", "/api/v1/users/me", headers={"Authorization": f"Bearer {token}"}
headers={"Authorization": f"Bearer {token}"}
) )
# Should be blocked # Should be blocked
assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN] assert response.status_code in [
status.HTTP_401_UNAUTHORIZED,
status.HTTP_403_FORBIDDEN,
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_nonexistent_organization_returns_403_not_404( async def test_nonexistent_organization_returns_403_not_404(
self, self, client, regular_user_token
client,
regular_user_token
): ):
""" """
Test that accessing nonexistent org returns 403, not 404. Test that accessing nonexistent org returns 403, not 404.
@@ -254,7 +237,7 @@ class TestPermissionEdgeCases:
fake_org_id = uuid4() fake_org_id = uuid4()
response = await client.get( response = await client.get(
f"/api/v1/organizations/{fake_org_id}", f"/api/v1/organizations/{fake_org_id}",
headers={"Authorization": f"Bearer {regular_user_token}"} headers={"Authorization": f"Bearer {regular_user_token}"},
) )
# Should get 403 (not a member), not 404 (doesn't exist) # Should get 403 (not a member), not 404 (doesn't exist)
@@ -264,18 +247,16 @@ class TestPermissionEdgeCases:
# ===== Admin Role Tests ===== # ===== Admin Role Tests =====
class TestAdminRolePermissions: class TestAdminRolePermissions:
"""Test admin role can perform admin actions.""" """Test admin role can perform admin actions."""
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_org_with_admin(self, async_test_db, async_test_user): async def test_org_with_admin(self, async_test_db, async_test_user):
"""Create org where user is ADMIN.""" """Create org where user is ADMIN."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org = Organization( org = Organization(name="Admin Org", slug="admin-org")
name="Admin Org",
slug="admin-org"
)
session.add(org) session.add(org)
await session.commit() await session.commit()
await session.refresh(org) await session.refresh(org)
@@ -284,7 +265,7 @@ class TestAdminRolePermissions:
user_id=async_test_user.id, user_id=async_test_user.id,
organization_id=org.id, organization_id=org.id,
role=OrganizationRole.ADMIN, role=OrganizationRole.ADMIN,
is_active=True is_active=True,
) )
session.add(membership) session.add(membership)
await session.commit() await session.commit()
@@ -293,16 +274,13 @@ class TestAdminRolePermissions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_can_update_org( async def test_admin_can_update_org(
self, self, client, regular_user_token, test_org_with_admin
client,
regular_user_token,
test_org_with_admin
): ):
"""Admin should be able to update organization.""" """Admin should be able to update organization."""
response = await client.put( response = await client.put(
f"/api/v1/organizations/{test_org_with_admin.id}", f"/api/v1/organizations/{test_org_with_admin.id}",
json={"name": "Updated by Admin"}, json={"name": "Updated by Admin"},
headers={"Authorization": f"Bearer {regular_user_token}"} headers={"Authorization": f"Bearer {regular_user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK

View File

@@ -7,13 +7,13 @@ Critical security tests covering:
These tests prevent unauthorized access and privilege escalation. These tests prevent unauthorized access and privilege escalation.
""" """
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.user import User
from app.models.organization import Organization
from app.crud.user import user as user_crud from app.crud.user import user as user_crud
from app.models.organization import Organization
from app.models.user import User
class TestInactiveUserBlocking: class TestInactiveUserBlocking:
@@ -29,11 +29,7 @@ class TestInactiveUserBlocking:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_inactive_user_cannot_access_protected_endpoints( async def test_inactive_user_cannot_access_protected_endpoints(
self, self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
client: AsyncClient,
async_test_db,
async_test_user: User,
user_token: str
): ):
""" """
Test that inactive users are blocked from protected endpoints. Test that inactive users are blocked from protected endpoints.
@@ -44,12 +40,11 @@ class TestInactiveUserBlocking:
3. User tries to access protected endpoint with valid token 3. User tries to access protected endpoint with valid token
4. System MUST reject (account inactive) 4. System MUST reject (account inactive)
""" """
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Step 1: Verify user can access endpoint while active # Step 1: Verify user can access endpoint while active
response = await client.get( response = await client.get(
"/api/v1/users/me", "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"}
) )
assert response.status_code == 200, "Active user should have access" assert response.status_code == 200, "Active user should have access"
@@ -61,8 +56,7 @@ class TestInactiveUserBlocking:
# Step 3: User tries to access endpoint with same token # Step 3: User tries to access endpoint with same token
response = await client.get( response = await client.get(
"/api/v1/users/me", "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"}
) )
# Step 4: System MUST reject (covers lines 52-57) # Step 4: System MUST reject (covers lines 52-57)
@@ -75,18 +69,14 @@ class TestInactiveUserBlocking:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_inactive_user_blocked_from_organization_endpoints( async def test_inactive_user_blocked_from_organization_endpoints(
self, self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
client: AsyncClient,
async_test_db,
async_test_user: User,
user_token: str
): ):
""" """
Test that inactive users can't access organization endpoints. Test that inactive users can't access organization endpoints.
Ensures the inactive check applies to ALL protected endpoints. Ensures the inactive check applies to ALL protected endpoints.
""" """
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Deactivate user # Deactivate user
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -97,7 +87,7 @@ class TestInactiveUserBlocking:
# Try to list organizations # Try to list organizations
response = await client.get( response = await client.get(
"/api/v1/organizations/me", "/api/v1/organizations/me",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
# Must be blocked # Must be blocked
@@ -122,7 +112,7 @@ class TestSuperuserPrivilegeEscalation:
client: AsyncClient, client: AsyncClient,
async_test_db, async_test_db,
async_test_superuser: User, async_test_superuser: User,
superuser_token: str superuser_token: str,
): ):
""" """
Test that superusers automatically get OWNER role in organizations. Test that superusers automatically get OWNER role in organizations.
@@ -131,14 +121,11 @@ class TestSuperuserPrivilegeEscalation:
Superusers can manage any organization without being explicitly added. Superusers can manage any organization without being explicitly added.
This is for platform administration. This is for platform administration.
""" """
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Step 1: Create an organization (owned by someone else) # Step 1: Create an organization (owned by someone else)
async with SessionLocal() as session: async with SessionLocal() as session:
org = Organization( org = Organization(name="Test Organization", slug="test-org")
name="Test Organization",
slug="test-org"
)
session.add(org) session.add(org)
await session.commit() await session.commit()
await session.refresh(org) await session.refresh(org)
@@ -148,7 +135,7 @@ class TestSuperuserPrivilegeEscalation:
# (They're not a member, but should auto-get OWNER role) # (They're not a member, but should auto-get OWNER role)
response = await client.get( response = await client.get(
f"/api/v1/organizations/{org_id}", f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
# Step 3: Should have access (covers lines 154-157) # Step 3: Should have access (covers lines 154-157)
@@ -161,21 +148,18 @@ class TestSuperuserPrivilegeEscalation:
client: AsyncClient, client: AsyncClient,
async_test_db, async_test_db,
async_test_superuser: User, async_test_superuser: User,
superuser_token: str superuser_token: str,
): ):
""" """
Test that superusers have full management access to all organizations. Test that superusers have full management access to all organizations.
Ensures the OWNER role privilege escalation works end-to-end. Ensures the OWNER role privilege escalation works end-to-end.
""" """
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create an organization # Create an organization
async with SessionLocal() as session: async with SessionLocal() as session:
org = Organization( org = Organization(name="Test Organization", slug="test-org")
name="Test Organization",
slug="test-org"
)
session.add(org) session.add(org)
await session.commit() await session.commit()
await session.refresh(org) await session.refresh(org)
@@ -185,34 +169,29 @@ class TestSuperuserPrivilegeEscalation:
response = await client.put( response = await client.put(
f"/api/v1/organizations/{org_id}", f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {superuser_token}"}, headers={"Authorization": f"Bearer {superuser_token}"},
json={"name": "Updated Name"} json={"name": "Updated Name"},
) )
# Should succeed (superuser has OWNER privileges) # Should succeed (superuser has OWNER privileges)
assert response.status_code in [200, 404], "Superuser should be able to manage any org" assert response.status_code in [200, 404], (
"Superuser should be able to manage any org"
)
# Note: Might be 404 if org endpoints require membership, but the role check passes # Note: Might be 404 if org endpoints require membership, but the role check passes
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_does_not_get_owner_role( async def test_regular_user_does_not_get_owner_role(
self, self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
client: AsyncClient,
async_test_db,
async_test_user: User,
user_token: str
): ):
""" """
Sanity check: Regular users don't get automatic OWNER role. Sanity check: Regular users don't get automatic OWNER role.
Ensures the superuser check is working correctly (line 154). Ensures the superuser check is working correctly (line 154).
""" """
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create an organization # Create an organization
async with SessionLocal() as session: async with SessionLocal() as session:
org = Organization( org = Organization(name="Test Organization", slug="test-org")
name="Test Organization",
slug="test-org"
)
session.add(org) session.add(org)
await session.commit() await session.commit()
await session.refresh(org) await session.refresh(org)
@@ -221,8 +200,10 @@ class TestSuperuserPrivilegeEscalation:
# Regular user tries to access it (not a member) # Regular user tries to access it (not a member)
response = await client.get( response = await client.get(
f"/api/v1/organizations/{org_id}", f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
# Should be denied (not a member, not a superuser) # Should be denied (not a member, not a superuser)
assert response.status_code in [403, 404], "Regular user shouldn't access non-member org" assert response.status_code in [403, 404], (
"Regular user shouldn't access non-member org"
)

View File

@@ -1,7 +1,8 @@
# tests/api/test_security_headers.py # tests/api/test_security_headers.py
from unittest.mock import patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from unittest.mock import patch
from app.main import app from app.main import app
@@ -11,8 +12,10 @@ def client():
"""Create a FastAPI test client for the main app (module-scoped for speed).""" """Create a FastAPI test client for the main app (module-scoped for speed)."""
# Mock get_db to avoid database connection issues # Mock get_db to avoid database connection issues
with patch("app.core.database.get_db") as mock_get_db: with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator(): async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock() mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None) mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None) mock_session.close = AsyncMock(return_value=None)
@@ -77,8 +80,10 @@ class TestSecurityHeaders:
"""Test that HSTS header is set in production (covers line 95)""" """Test that HSTS header is set in production (covers line 95)"""
with patch("app.core.config.settings.ENVIRONMENT", "production"): with patch("app.core.config.settings.ENVIRONMENT", "production"):
with patch("app.core.database.get_db") as mock_get_db: with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator(): async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock() mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None) mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None) mock_session.close = AsyncMock(return_value=None)
@@ -88,20 +93,26 @@ class TestSecurityHeaders:
# Need to reimport app to pick up the new settings # Need to reimport app to pick up the new settings
from importlib import reload from importlib import reload
import app.main import app.main
reload(app.main) reload(app.main)
test_client = TestClient(app.main.app) test_client = TestClient(app.main.app)
response = test_client.get("/health") response = test_client.get("/health")
assert "Strict-Transport-Security" in response.headers assert "Strict-Transport-Security" in response.headers
assert "max-age=31536000" in response.headers["Strict-Transport-Security"] assert (
"max-age=31536000" in response.headers["Strict-Transport-Security"]
)
def test_csp_strict_mode(self): def test_csp_strict_mode(self):
"""Test CSP strict mode (covers line 121)""" """Test CSP strict mode (covers line 121)"""
with patch("app.core.config.settings.CSP_MODE", "strict"): with patch("app.core.config.settings.CSP_MODE", "strict"):
with patch("app.core.database.get_db") as mock_get_db: with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator(): async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock() mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None) mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None) mock_session.close = AsyncMock(return_value=None)
@@ -110,7 +121,9 @@ class TestSecurityHeaders:
mock_get_db.side_effect = lambda: mock_session_generator() mock_get_db.side_effect = lambda: mock_session_generator()
from importlib import reload from importlib import reload
import app.main import app.main
reload(app.main) reload(app.main)
test_client = TestClient(app.main.app) test_client = TestClient(app.main.app)
@@ -136,8 +149,10 @@ class TestRootEndpoint:
def test_root_endpoint(self): def test_root_endpoint(self):
"""Test root endpoint returns HTML (covers line 174)""" """Test root endpoint returns HTML (covers line 174)"""
with patch("app.core.database.get_db") as mock_get_db: with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator(): async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock() mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None) mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None) mock_session.close = AsyncMock(return_value=None)

View File

@@ -2,23 +2,23 @@
""" """
Comprehensive tests for session management API endpoints. Comprehensive tests for session management API endpoints.
""" """
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
from uuid import uuid4
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from unittest.mock import patch
from fastapi import status from fastapi import status
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.schemas.users import UserCreate
# Disable rate limiting for tests # Disable rate limiting for tests
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def disable_rate_limit(): def disable_rate_limit():
"""Disable rate limiting for all tests in this module.""" """Disable rate limiting for all tests in this module."""
with patch('app.api.routes.sessions.limiter.enabled', False): with patch("app.api.routes.sessions.limiter.enabled", False):
yield yield
@@ -27,10 +27,7 @@ async def user_token(client, async_test_user):
"""Create and return an access token for async_test_user.""" """Create and return an access token for async_test_user."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
assert response.status_code == 200 assert response.status_code == 200
return response.json()["access_token"] return response.json()["access_token"]
@@ -39,7 +36,7 @@ async def user_token(client, async_test_user):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def async_test_user2(async_test_db): async def async_test_user2(async_test_db):
"""Create a second test user.""" """Create a second test user."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
from app.crud.user import user as user_crud from app.crud.user import user as user_crud
@@ -49,7 +46,7 @@ async def async_test_user2(async_test_db):
email="testuser2@example.com", email="testuser2@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Test", first_name="Test",
last_name="User2" last_name="User2",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
await session.commit() await session.commit()
@@ -61,9 +58,11 @@ class TestListMySessions:
"""Tests for GET /api/v1/sessions/me endpoint.""" """Tests for GET /api/v1/sessions/me endpoint."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_my_sessions_success(self, client, async_test_user, async_test_db, user_token): async def test_list_my_sessions_success(
self, client, async_test_user, async_test_db, user_token
):
"""Test successfully listing user's active sessions.""" """Test successfully listing user's active sessions."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create some sessions for the user # Create some sessions for the user
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -75,8 +74,8 @@ class TestListMySessions:
ip_address="192.168.1.100", ip_address="192.168.1.100",
user_agent="Mozilla/5.0 (iPhone)", user_agent="Mozilla/5.0 (iPhone)",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
# Active session 2 # Active session 2
s2 = UserSession( s2 = UserSession(
@@ -86,8 +85,8 @@ class TestListMySessions:
ip_address="192.168.1.101", ip_address="192.168.1.101",
user_agent="Mozilla/5.0 (Macintosh)", user_agent="Mozilla/5.0 (Macintosh)",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1) last_used_at=datetime.now(UTC) - timedelta(hours=1),
) )
# Inactive session (should not appear) # Inactive session (should not appear)
s3 = UserSession( s3 = UserSession(
@@ -97,16 +96,15 @@ class TestListMySessions:
ip_address="192.168.1.102", ip_address="192.168.1.102",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(days=1) last_used_at=datetime.now(UTC) - timedelta(days=1),
) )
session.add_all([s1, s2, s3]) session.add_all([s1, s2, s3])
await session.commit() await session.commit()
# Make request # Make request
response = await client.get( response = await client.get(
"/api/v1/sessions/me", "/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -128,11 +126,12 @@ class TestListMySessions:
assert data["sessions"][0]["is_current"] is True assert data["sessions"][0]["is_current"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_my_sessions_with_login_session(self, client, async_test_user, user_token): async def test_list_my_sessions_with_login_session(
self, client, async_test_user, user_token
):
"""Test listing sessions shows the login session.""" """Test listing sessions shows the login session."""
response = await client.get( response = await client.get(
"/api/v1/sessions/me", "/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -155,9 +154,11 @@ class TestRevokeSession:
"""Tests for DELETE /api/v1/sessions/{session_id} endpoint.""" """Tests for DELETE /api/v1/sessions/{session_id} endpoint."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_revoke_session_success(self, client, async_test_user, async_test_db, user_token): async def test_revoke_session_success(
self, client, async_test_user, async_test_db, user_token
):
"""Test successfully revoking a session.""" """Test successfully revoking a session."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create a session to revoke # Create a session to revoke
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -168,8 +169,8 @@ class TestRevokeSession:
ip_address="192.168.1.103", ip_address="192.168.1.103",
user_agent="Mozilla/5.0 (iPad)", user_agent="Mozilla/5.0 (iPad)",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -179,7 +180,7 @@ class TestRevokeSession:
# Revoke the session # Revoke the session
response = await client.delete( response = await client.delete(
f"/api/v1/sessions/{session_id}", f"/api/v1/sessions/{session_id}",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -191,6 +192,7 @@ class TestRevokeSession:
# Verify session is deactivated # Verify session is deactivated
async with SessionLocal() as session: async with SessionLocal() as session:
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
revoked_session = await session_crud.get(session, id=str(session_id)) revoked_session = await session_crud.get(session, id=str(session_id))
assert revoked_session.is_active is False assert revoked_session.is_active is False
@@ -200,7 +202,7 @@ class TestRevokeSession:
fake_id = uuid4() fake_id = uuid4()
response = await client.delete( response = await client.delete(
f"/api/v1/sessions/{fake_id}", f"/api/v1/sessions/{fake_id}",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -222,7 +224,7 @@ class TestRevokeSession:
self, client, async_test_user, async_test_user2, async_test_db, user_token self, client, async_test_user, async_test_user2, async_test_db, user_token
): ):
"""Test that users cannot revoke other users' sessions.""" """Test that users cannot revoke other users' sessions."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create a session for user2 # Create a session for user2
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -233,8 +235,8 @@ class TestRevokeSession:
ip_address="192.168.1.200", ip_address="192.168.1.200",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(other_user_session) session.add(other_user_session)
await session.commit() await session.commit()
@@ -244,7 +246,7 @@ class TestRevokeSession:
# Try to revoke it as user1 # Try to revoke it as user1
response = await client.delete( response = await client.delete(
f"/api/v1/sessions/{session_id}", f"/api/v1/sessions/{session_id}",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -263,7 +265,7 @@ class TestCleanupExpiredSessions:
self, client, async_test_user, async_test_db, user_token self, client, async_test_user, async_test_db, user_token
): ):
"""Test successfully cleaning up expired sessions.""" """Test successfully cleaning up expired sessions."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create expired and active sessions using CRUD to avoid greenlet issues # Create expired and active sessions using CRUD to avoid greenlet issues
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
@@ -277,8 +279,8 @@ class TestCleanupExpiredSessions:
device_name="Expired 1", device_name="Expired 1",
ip_address="192.168.1.201", ip_address="192.168.1.201",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2) last_used_at=datetime.now(UTC) - timedelta(days=2),
) )
e1 = await session_crud.create_session(db, obj_in=e1_data) e1 = await session_crud.create_session(db, obj_in=e1_data)
e1.is_active = False e1.is_active = False
@@ -291,8 +293,8 @@ class TestCleanupExpiredSessions:
device_name="Expired 2", device_name="Expired 2",
ip_address="192.168.1.202", ip_address="192.168.1.202",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(hours=1), expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2) last_used_at=datetime.now(UTC) - timedelta(hours=2),
) )
e2 = await session_crud.create_session(db, obj_in=e2_data) e2 = await session_crud.create_session(db, obj_in=e2_data)
e2.is_active = False e2.is_active = False
@@ -305,8 +307,8 @@ class TestCleanupExpiredSessions:
device_name="Active", device_name="Active",
ip_address="192.168.1.203", ip_address="192.168.1.203",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
await session_crud.create_session(db, obj_in=a1_data) await session_crud.create_session(db, obj_in=a1_data)
await db.commit() await db.commit()
@@ -314,7 +316,7 @@ class TestCleanupExpiredSessions:
# Cleanup expired sessions # Cleanup expired sessions
response = await client.delete( response = await client.delete(
"/api/v1/sessions/me/expired", "/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -329,7 +331,7 @@ class TestCleanupExpiredSessions:
self, client, async_test_user, async_test_db, user_token self, client, async_test_user, async_test_db, user_token
): ):
"""Test cleanup when no sessions are expired.""" """Test cleanup when no sessions are expired."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create only active sessions using CRUD # Create only active sessions using CRUD
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
@@ -342,15 +344,15 @@ class TestCleanupExpiredSessions:
device_name="Active Device", device_name="Active Device",
ip_address="192.168.1.210", ip_address="192.168.1.210",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
await session_crud.create_session(db, obj_in=a1_data) await session_crud.create_session(db, obj_in=a1_data)
await db.commit() await db.commit()
response = await client.delete( response = await client.delete(
"/api/v1/sessions/me/expired", "/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -369,13 +371,16 @@ class TestCleanupExpiredSessions:
# Additional tests for better coverage # Additional tests for better coverage
class TestSessionsAdditionalCases: class TestSessionsAdditionalCases:
"""Additional tests to improve sessions endpoint coverage.""" """Additional tests to improve sessions endpoint coverage."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_sessions_pagination(self, client, async_test_user, async_test_db, user_token): async def test_list_sessions_pagination(
self, client, async_test_user, async_test_db, user_token
):
"""Test listing sessions with pagination.""" """Test listing sessions with pagination."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create multiple sessions # Create multiple sessions
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -389,15 +394,15 @@ class TestSessionsAdditionalCases:
device_name=f"Device {i}", device_name=f"Device {i}",
ip_address=f"192.168.1.{i}", ip_address=f"192.168.1.{i}",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
await session_crud.create_session(session, obj_in=session_data) await session_crud.create_session(session, obj_in=session_data)
await session.commit() await session.commit()
response = await client.get( response = await client.get(
"/api/v1/sessions/me?page=1&limit=3", "/api/v1/sessions/me?page=1&limit=3",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -410,16 +415,21 @@ class TestSessionsAdditionalCases:
"""Test revoking session with invalid UUID.""" """Test revoking session with invalid UUID."""
response = await client.delete( response = await client.delete(
"/api/v1/sessions/not-a-uuid", "/api/v1/sessions/not-a-uuid",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
# Should return 422 for invalid UUID format # Should return 422 for invalid UUID format
assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_404_NOT_FOUND] assert response.status_code in [
status.HTTP_422_UNPROCESSABLE_ENTITY,
status.HTTP_404_NOT_FOUND,
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_sessions_with_mixed_states(self, client, async_test_user, async_test_db, user_token): async def test_cleanup_expired_sessions_with_mixed_states(
self, client, async_test_user, async_test_db, user_token
):
"""Test cleanup with mix of active/inactive and expired/not-expired sessions.""" """Test cleanup with mix of active/inactive and expired/not-expired sessions."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
@@ -432,8 +442,8 @@ class TestSessionsAdditionalCases:
device_name="Expired Inactive", device_name="Expired Inactive",
ip_address="192.168.1.100", ip_address="192.168.1.100",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2) last_used_at=datetime.now(UTC) - timedelta(days=2),
) )
e1 = await session_crud.create_session(db, obj_in=e1_data) e1 = await session_crud.create_session(db, obj_in=e1_data)
e1.is_active = False e1.is_active = False
@@ -446,8 +456,8 @@ class TestSessionsAdditionalCases:
device_name="Expired Active", device_name="Expired Active",
ip_address="192.168.1.101", ip_address="192.168.1.101",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(hours=1), expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2) last_used_at=datetime.now(UTC) - timedelta(hours=2),
) )
await session_crud.create_session(db, obj_in=e2_data) await session_crud.create_session(db, obj_in=e2_data)
@@ -455,7 +465,7 @@ class TestSessionsAdditionalCases:
response = await client.delete( response = await client.delete(
"/api/v1/sessions/me/expired", "/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -476,10 +486,12 @@ class TestSessionExceptionHandlers:
from unittest.mock import patch from unittest.mock import patch
# Patch decode_token to raise an exception # Patch decode_token to raise an exception
with patch('app.api.routes.sessions.decode_token', side_effect=Exception("Token decode error")): with patch(
"app.api.routes.sessions.decode_token",
side_effect=Exception("Token decode error"),
):
response = await client.get( response = await client.get(
"/api/v1/sessions/me", "/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"}
) )
# Should still succeed (exception is caught and ignored in try/except at line 77) # Should still succeed (exception is caught and ignored in try/except at line 77)
@@ -489,12 +501,16 @@ class TestSessionExceptionHandlers:
async def test_list_sessions_database_error(self, client, user_token): async def test_list_sessions_database_error(self, client, user_token):
"""Test list_sessions handles database errors (covers lines 104-106).""" """Test list_sessions handles database errors (covers lines 104-106)."""
from unittest.mock import patch from unittest.mock import patch
from app.crud import session as session_module from app.crud import session as session_module
with patch.object(session_module.session, 'get_user_sessions', side_effect=Exception("Database error")): with patch.object(
session_module.session,
"get_user_sessions",
side_effect=Exception("Database error"),
):
response = await client.get( response = await client.get(
"/api/v1/sessions/me", "/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"}
) )
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -503,18 +519,21 @@ class TestSessionExceptionHandlers:
assert data["errors"][0]["message"] == "Failed to retrieve sessions" assert data["errors"][0]["message"] == "Failed to retrieve sessions"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_revoke_session_database_error(self, client, user_token, async_test_db, async_test_user): async def test_revoke_session_database_error(
self, client, user_token, async_test_db, async_test_user
):
"""Test revoke_session handles database errors (covers lines 181-183).""" """Test revoke_session handles database errors (covers lines 181-183)."""
from datetime import datetime, timedelta
from unittest.mock import patch from unittest.mock import patch
from uuid import uuid4 from uuid import uuid4
from app.crud import session as session_module from app.crud import session as session_module
# First create a session to revoke # First create a session to revoke
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
from datetime import datetime, timedelta, timezone
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as db: async with AsyncTestingSessionLocal() as db:
session_in = SessionCreate( session_in = SessionCreate(
@@ -523,17 +542,21 @@ class TestSessionExceptionHandlers:
device_name="Test Device", device_name="Test Device",
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
last_used_at=datetime.now(timezone.utc), last_used_at=datetime.now(UTC),
expires_at=datetime.now(timezone.utc) + timedelta(days=60) expires_at=datetime.now(UTC) + timedelta(days=60),
) )
user_session = await session_crud.create_session(db, obj_in=session_in) user_session = await session_crud.create_session(db, obj_in=session_in)
session_id = user_session.id session_id = user_session.id
# Mock the deactivate method to raise an exception # Mock the deactivate method to raise an exception
with patch.object(session_module.session, 'deactivate', side_effect=Exception("Database connection lost")): with patch.object(
session_module.session,
"deactivate",
side_effect=Exception("Database connection lost"),
):
response = await client.delete( response = await client.delete(
f"/api/v1/sessions/{session_id}", f"/api/v1/sessions/{session_id}",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -544,12 +567,17 @@ class TestSessionExceptionHandlers:
async def test_cleanup_expired_sessions_database_error(self, client, user_token): async def test_cleanup_expired_sessions_database_error(self, client, user_token):
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236).""" """Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
from unittest.mock import patch from unittest.mock import patch
from app.crud import session as session_module from app.crud import session as session_module
with patch.object(session_module.session, 'cleanup_expired_for_user', side_effect=Exception("Cleanup failed")): with patch.object(
session_module.session,
"cleanup_expired_for_user",
side_effect=Exception("Cleanup failed"),
):
response = await client.delete( response = await client.delete(
"/api/v1/sessions/me/expired", "/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"} headers={"Authorization": f"Bearer {user_token}"},
) )
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

View File

@@ -3,32 +3,29 @@
Comprehensive tests for user management endpoints. Comprehensive tests for user management endpoints.
These tests focus on finding potential bugs, not just coverage. These tests focus on finding potential bugs, not just coverage.
""" """
import pytest
import pytest_asyncio
from unittest.mock import patch
from fastapi import status
import uuid
from sqlalchemy import select import uuid
from unittest.mock import patch
import pytest
from fastapi import status
from app.models.user import User from app.models.user import User
from app.models.user import User
from app.schemas.users import UserUpdate
# Disable rate limiting for tests # Disable rate limiting for tests
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def disable_rate_limit(): def disable_rate_limit():
"""Disable rate limiting for all tests in this module.""" """Disable rate limiting for all tests in this module."""
with patch('app.api.routes.users.limiter.enabled', False): with patch("app.api.routes.users.limiter.enabled", False):
with patch('app.api.routes.auth.limiter.enabled', False): with patch("app.api.routes.auth.limiter.enabled", False):
yield yield
async def get_auth_headers(client, email, password): async def get_auth_headers(client, email, password):
"""Helper to get authentication headers.""" """Helper to get authentication headers."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login", json={"email": email, "password": password}
json={"email": email, "password": password}
) )
token = response.json()["access_token"] token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"} return {"Authorization": f"Bearer {token}"}
@@ -40,7 +37,9 @@ class TestListUsers:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_users_as_superuser(self, client, async_test_superuser): async def test_list_users_as_superuser(self, client, async_test_superuser):
"""Test listing users as superuser.""" """Test listing users as superuser."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get("/api/v1/users", headers=headers) response = await client.get("/api/v1/users", headers=headers)
@@ -53,16 +52,20 @@ class TestListUsers:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_users_as_regular_user(self, client, async_test_user): async def test_list_users_as_regular_user(self, client, async_test_user):
"""Test that regular users cannot list users.""" """Test that regular users cannot list users."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get("/api/v1/users", headers=headers) response = await client.get("/api/v1/users", headers=headers)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_users_pagination(self, client, async_test_superuser, async_test_db): async def test_list_users_pagination(
self, client, async_test_superuser, async_test_db
):
"""Test pagination works correctly.""" """Test pagination works correctly."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users # Create multiple users
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -72,12 +75,14 @@ class TestListUsers:
password_hash="hash", password_hash="hash",
first_name=f"PagUser{i}", first_name=f"PagUser{i}",
is_active=True, is_active=True,
is_superuser=False is_superuser=False,
) )
session.add(user) session.add(user)
await session.commit() await session.commit()
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
# Get first page # Get first page
response = await client.get("/api/v1/users?page=1&limit=5", headers=headers) response = await client.get("/api/v1/users?page=1&limit=5", headers=headers)
@@ -88,9 +93,11 @@ class TestListUsers:
assert data["pagination"]["total"] >= 15 assert data["pagination"]["total"] >= 15
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_users_filter_active(self, client, async_test_superuser, async_test_db): async def test_list_users_filter_active(
self, client, async_test_superuser, async_test_db
):
"""Test filtering by active status.""" """Test filtering by active status."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create active and inactive users # Create active and inactive users
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -99,19 +106,21 @@ class TestListUsers:
password_hash="hash", password_hash="hash",
first_name="Active", first_name="Active",
is_active=True, is_active=True,
is_superuser=False is_superuser=False,
) )
inactive_user = User( inactive_user = User(
email="inactivefilter@example.com", email="inactivefilter@example.com",
password_hash="hash", password_hash="hash",
first_name="Inactive", first_name="Inactive",
is_active=False, is_active=False,
is_superuser=False is_superuser=False,
) )
session.add_all([active_user, inactive_user]) session.add_all([active_user, inactive_user])
await session.commit() await session.commit()
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
# Filter for active users # Filter for active users
response = await client.get("/api/v1/users?is_active=true", headers=headers) response = await client.get("/api/v1/users?is_active=true", headers=headers)
@@ -130,9 +139,13 @@ class TestListUsers:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_users_sort_by_email(self, client, async_test_superuser): async def test_list_users_sort_by_email(self, client, async_test_superuser):
"""Test sorting users by email.""" """Test sorting users by email."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get("/api/v1/users?sort_by=email&sort_order=asc", headers=headers) response = await client.get(
"/api/v1/users?sort_by=email&sort_order=asc", headers=headers
)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
emails = [u["email"] for u in data["data"]] emails = [u["email"] for u in data["data"]]
@@ -154,7 +167,9 @@ class TestGetCurrentUserProfile:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_own_profile(self, client, async_test_user): async def test_get_own_profile(self, client, async_test_user):
"""Test getting own profile.""" """Test getting own profile."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get("/api/v1/users/me", headers=headers) response = await client.get("/api/v1/users/me", headers=headers)
@@ -176,12 +191,14 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_own_profile(self, client, async_test_user): async def test_update_own_profile(self, client, async_test_user):
"""Test updating own profile.""" """Test updating own profile."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch( response = await client.patch(
"/api/v1/users/me", "/api/v1/users/me",
headers=headers, headers=headers,
json={"first_name": "Updated", "last_name": "Name"} json={"first_name": "Updated", "last_name": "Name"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -192,12 +209,12 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_profile_phone_number(self, client, async_test_user, test_db): async def test_update_profile_phone_number(self, client, async_test_user, test_db):
"""Test updating phone number with validation.""" """Test updating phone number with validation."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch( response = await client.patch(
"/api/v1/users/me", "/api/v1/users/me", headers=headers, json={"phone_number": "+19876543210"}
headers=headers,
json={"phone_number": "+19876543210"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -207,12 +224,12 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_profile_invalid_phone(self, client, async_test_user): async def test_update_profile_invalid_phone(self, client, async_test_user):
"""Test that invalid phone numbers are rejected.""" """Test that invalid phone numbers are rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch( response = await client.patch(
"/api/v1/users/me", "/api/v1/users/me", headers=headers, json={"phone_number": "invalid"}
headers=headers,
json={"phone_number": "invalid"}
) )
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -220,14 +237,16 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_elevate_to_superuser(self, client, async_test_user): async def test_cannot_elevate_to_superuser(self, client, async_test_user):
"""Test that users cannot make themselves superuser.""" """Test that users cannot make themselves superuser."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
# Note: is_superuser is now in UserUpdate schema with explicit validation # Note: is_superuser is now in UserUpdate schema with explicit validation
# This tests that Pydantic rejects the attempt at the schema level # This tests that Pydantic rejects the attempt at the schema level
response = await client.patch( response = await client.patch(
"/api/v1/users/me", "/api/v1/users/me",
headers=headers, headers=headers,
json={"first_name": "Test", "is_superuser": True} json={"first_name": "Test", "is_superuser": True},
) )
# Pydantic validation should reject this at the schema level # Pydantic validation should reject this at the schema level
@@ -242,10 +261,7 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_profile_no_auth(self, client): async def test_update_profile_no_auth(self, client):
"""Test that unauthenticated requests are rejected.""" """Test that unauthenticated requests are rejected."""
response = await client.patch( response = await client.patch("/api/v1/users/me", json={"first_name": "Hacker"})
"/api/v1/users/me",
json={"first_name": "Hacker"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
# Note: Removed test_update_profile_unexpected_error - see comment above # Note: Removed test_update_profile_unexpected_error - see comment above
@@ -257,16 +273,22 @@ class TestGetUserById:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_own_profile_by_id(self, client, async_test_user): async def test_get_own_profile_by_id(self, client, async_test_user):
"""Test getting own profile by ID.""" """Test getting own profile by ID."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers) response = await client.get(
f"/api/v1/users/{async_test_user.id}", headers=headers
)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
assert data["email"] == async_test_user.email assert data["email"] == async_test_user.email
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_other_user_as_regular_user(self, client, async_test_user, test_db): async def test_get_other_user_as_regular_user(
self, client, async_test_user, test_db
):
"""Test that regular users cannot view other profiles.""" """Test that regular users cannot view other profiles."""
# Create another user # Create another user
other_user = User( other_user = User(
@@ -274,24 +296,32 @@ class TestGetUserById:
password_hash="hash", password_hash="hash",
first_name="Other", first_name="Other",
is_active=True, is_active=True,
is_superuser=False is_superuser=False,
) )
test_db.add(other_user) test_db.add(other_user)
test_db.commit() test_db.commit()
test_db.refresh(other_user) test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get(f"/api/v1/users/{other_user.id}", headers=headers) response = await client.get(f"/api/v1/users/{other_user.id}", headers=headers)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_other_user_as_superuser(self, client, async_test_superuser, async_test_user): async def test_get_other_user_as_superuser(
self, client, async_test_superuser, async_test_user
):
"""Test that superusers can view other profiles.""" """Test that superusers can view other profiles."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers) response = await client.get(
f"/api/v1/users/{async_test_user.id}", headers=headers
)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
@@ -300,7 +330,9 @@ class TestGetUserById:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_nonexistent_user(self, client, async_test_superuser): async def test_get_nonexistent_user(self, client, async_test_superuser):
"""Test getting non-existent user.""" """Test getting non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
fake_id = uuid.uuid4() fake_id = uuid.uuid4()
response = await client.get(f"/api/v1/users/{fake_id}", headers=headers) response = await client.get(f"/api/v1/users/{fake_id}", headers=headers)
@@ -310,7 +342,9 @@ class TestGetUserById:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_invalid_uuid(self, client, async_test_superuser): async def test_get_user_invalid_uuid(self, client, async_test_superuser):
"""Test getting user with invalid UUID format.""" """Test getting user with invalid UUID format."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get("/api/v1/users/not-a-uuid", headers=headers) response = await client.get("/api/v1/users/not-a-uuid", headers=headers)
@@ -323,12 +357,14 @@ class TestUpdateUserById:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_own_profile_by_id(self, client, async_test_user, test_db): async def test_update_own_profile_by_id(self, client, async_test_user, test_db):
"""Test updating own profile by ID.""" """Test updating own profile by ID."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch( response = await client.patch(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers=headers, headers=headers,
json={"first_name": "SelfUpdated"} json={"first_name": "SelfUpdated"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -336,7 +372,9 @@ class TestUpdateUserById:
assert data["first_name"] == "SelfUpdated" assert data["first_name"] == "SelfUpdated"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_other_user_as_regular_user(self, client, async_test_user, test_db): async def test_update_other_user_as_regular_user(
self, client, async_test_user, test_db
):
"""Test that regular users cannot update other profiles.""" """Test that regular users cannot update other profiles."""
# Create another user # Create another user
other_user = User( other_user = User(
@@ -344,18 +382,20 @@ class TestUpdateUserById:
password_hash="hash", password_hash="hash",
first_name="Other", first_name="Other",
is_active=True, is_active=True,
is_superuser=False is_superuser=False,
) )
test_db.add(other_user) test_db.add(other_user)
test_db.commit() test_db.commit()
test_db.refresh(other_user) test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch( response = await client.patch(
f"/api/v1/users/{other_user.id}", f"/api/v1/users/{other_user.id}",
headers=headers, headers=headers,
json={"first_name": "Hacked"} json={"first_name": "Hacked"},
) )
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -365,14 +405,18 @@ class TestUpdateUserById:
assert other_user.first_name == "Other" assert other_user.first_name == "Other"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_other_user_as_superuser(self, client, async_test_superuser, async_test_user, test_db): async def test_update_other_user_as_superuser(
self, client, async_test_superuser, async_test_user, test_db
):
"""Test that superusers can update other profiles.""" """Test that superusers can update other profiles."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.patch( response = await client.patch(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers=headers, headers=headers,
json={"first_name": "AdminUpdated"} json={"first_name": "AdminUpdated"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -380,16 +424,20 @@ class TestUpdateUserById:
assert data["first_name"] == "AdminUpdated" assert data["first_name"] == "AdminUpdated"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_modify_superuser_status(self, client, async_test_user): async def test_regular_user_cannot_modify_superuser_status(
self, client, async_test_user
):
"""Test that regular users cannot change superuser status even if they try.""" """Test that regular users cannot change superuser status even if they try."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
# is_superuser not in UserUpdate schema, so it gets ignored by Pydantic # is_superuser not in UserUpdate schema, so it gets ignored by Pydantic
# Just verify the user stays the same # Just verify the user stays the same
response = await client.patch( response = await client.patch(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers=headers, headers=headers,
json={"first_name": "Test"} json={"first_name": "Test"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -397,14 +445,18 @@ class TestUpdateUserById:
assert data["is_superuser"] is False assert data["is_superuser"] is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_superuser_can_update_users(self, client, async_test_superuser, async_test_user, test_db): async def test_superuser_can_update_users(
self, client, async_test_superuser, async_test_user, test_db
):
"""Test that superusers can update other users.""" """Test that superusers can update other users."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.patch( response = await client.patch(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers=headers, headers=headers,
json={"first_name": "AdminChanged", "is_active": False} json={"first_name": "AdminChanged", "is_active": False},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -415,13 +467,13 @@ class TestUpdateUserById:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_nonexistent_user(self, client, async_test_superuser): async def test_update_nonexistent_user(self, client, async_test_superuser):
"""Test updating non-existent user.""" """Test updating non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
fake_id = uuid.uuid4() fake_id = uuid.uuid4()
response = await client.patch( response = await client.patch(
f"/api/v1/users/{fake_id}", f"/api/v1/users/{fake_id}", headers=headers, json={"first_name": "Ghost"}
headers=headers,
json={"first_name": "Ghost"}
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -435,15 +487,17 @@ class TestChangePassword:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_password_success(self, client, async_test_user, test_db): async def test_change_password_success(self, client, async_test_user, test_db):
"""Test successful password change.""" """Test successful password change."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch( response = await client.patch(
"/api/v1/users/me/password", "/api/v1/users/me/password",
headers=headers, headers=headers,
json={ json={
"current_password": "TestPassword123!", "current_password": "TestPassword123!",
"new_password": "NewPassword123!" "new_password": "NewPassword123!",
} },
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -453,25 +507,24 @@ class TestChangePassword:
# Verify can login with new password # Verify can login with new password
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": async_test_user.email, "password": "NewPassword123!"},
"email": async_test_user.email,
"password": "NewPassword123!"
}
) )
assert login_response.status_code == status.HTTP_200_OK assert login_response.status_code == status.HTTP_200_OK
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_password_wrong_current(self, client, async_test_user): async def test_change_password_wrong_current(self, client, async_test_user):
"""Test that wrong current password is rejected.""" """Test that wrong current password is rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch( response = await client.patch(
"/api/v1/users/me/password", "/api/v1/users/me/password",
headers=headers, headers=headers,
json={ json={
"current_password": "WrongPassword123", "current_password": "WrongPassword123",
"new_password": "NewPassword123!" "new_password": "NewPassword123!",
} },
) )
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -479,15 +532,14 @@ class TestChangePassword:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_password_weak_new_password(self, client, async_test_user): async def test_change_password_weak_new_password(self, client, async_test_user):
"""Test that weak new passwords are rejected.""" """Test that weak new passwords are rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch( response = await client.patch(
"/api/v1/users/me/password", "/api/v1/users/me/password",
headers=headers, headers=headers,
json={ json={"current_password": "TestPassword123!", "new_password": "weak"},
"current_password": "TestPassword123!",
"new_password": "weak"
}
) )
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -499,8 +551,8 @@ class TestChangePassword:
"/api/v1/users/me/password", "/api/v1/users/me/password",
json={ json={
"current_password": "TestPassword123!", "current_password": "TestPassword123!",
"new_password": "NewPassword123!" "new_password": "NewPassword123!",
} },
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -511,9 +563,11 @@ class TestDeleteUser:
"""Tests for DELETE /users/{user_id} endpoint.""" """Tests for DELETE /users/{user_id} endpoint."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_user_as_superuser(self, client, async_test_superuser, async_test_db): async def test_delete_user_as_superuser(
self, client, async_test_superuser, async_test_db
):
"""Test deleting a user as superuser.""" """Test deleting a user as superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create a user to delete # Create a user to delete
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -522,14 +576,16 @@ class TestDeleteUser:
password_hash="hash", password_hash="hash",
first_name="Delete", first_name="Delete",
is_active=True, is_active=True,
is_superuser=False is_superuser=False,
) )
session.add(user_to_delete) session.add(user_to_delete)
await session.commit() await session.commit()
await session.refresh(user_to_delete) await session.refresh(user_to_delete)
user_id = user_to_delete.id user_id = user_to_delete.id
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.delete(f"/api/v1/users/{user_id}", headers=headers) response = await client.delete(f"/api/v1/users/{user_id}", headers=headers)
@@ -540,6 +596,7 @@ class TestDeleteUser:
# Verify user is soft-deleted (has deleted_at timestamp) # Verify user is soft-deleted (has deleted_at timestamp)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
from sqlalchemy import select from sqlalchemy import select
result = await session.execute(select(User).where(User.id == user_id)) result = await session.execute(select(User).where(User.id == user_id))
deleted_user = result.scalar_one_or_none() deleted_user = result.scalar_one_or_none()
assert deleted_user.deleted_at is not None assert deleted_user.deleted_at is not None
@@ -547,9 +604,13 @@ class TestDeleteUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_delete_self(self, client, async_test_superuser): async def test_cannot_delete_self(self, client, async_test_superuser):
"""Test that users cannot delete their own account.""" """Test that users cannot delete their own account."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.delete(f"/api/v1/users/{async_test_superuser.id}", headers=headers) response = await client.delete(
f"/api/v1/users/{async_test_superuser.id}", headers=headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -562,22 +623,28 @@ class TestDeleteUser:
password_hash="hash", password_hash="hash",
first_name="Protected", first_name="Protected",
is_active=True, is_active=True,
is_superuser=False is_superuser=False,
) )
test_db.add(other_user) test_db.add(other_user)
test_db.commit() test_db.commit()
test_db.refresh(other_user) test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!") headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.delete(f"/api/v1/users/{other_user.id}", headers=headers) response = await client.delete(
f"/api/v1/users/{other_user.id}", headers=headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_nonexistent_user(self, client, async_test_superuser): async def test_delete_nonexistent_user(self, client, async_test_superuser):
"""Test deleting non-existent user.""" """Test deleting non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!") headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
fake_id = uuid.uuid4() fake_id = uuid.uuid4()
response = await client.delete(f"/api/v1/users/{fake_id}", headers=headers) response = await client.delete(f"/api/v1/users/{fake_id}", headers=headers)

View File

@@ -2,10 +2,12 @@
""" """
Tests for user routes. Tests for user routes.
""" """
from uuid import uuid4
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from fastapi import status from fastapi import status
from uuid import uuid4
@pytest_asyncio.fixture @pytest_asyncio.fixture
@@ -13,10 +15,7 @@ async def superuser_token(client, async_test_superuser):
"""Get access token for superuser.""" """Get access token for superuser."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "superuser@example.com", "password": "SuperPassword123!"},
"email": "superuser@example.com",
"password": "SuperPassword123!"
}
) )
assert response.status_code == 200 assert response.status_code == 200
return response.json()["access_token"] return response.json()["access_token"]
@@ -27,10 +26,7 @@ async def user_token(client, async_test_user):
"""Get access token for regular user.""" """Get access token for regular user."""
response = await client.post( response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "testuser@example.com", "password": "TestPassword123!"},
"email": "testuser@example.com",
"password": "TestPassword123!"
}
) )
assert response.status_code == 200 assert response.status_code == 200
return response.json()["access_token"] return response.json()["access_token"]
@@ -43,8 +39,7 @@ class TestListUsers:
async def test_list_users_success(self, client, superuser_token): async def test_list_users_success(self, client, superuser_token):
"""Test listing users successfully (covers lines 87-100).""" """Test listing users successfully (covers lines 87-100)."""
response = await client.get( response = await client.get(
"/api/v1/users", "/api/v1/users", headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -58,7 +53,7 @@ class TestListUsers:
"""Test listing users with is_superuser filter (covers line 74).""" """Test listing users with is_superuser filter (covers line 74)."""
response = await client.get( response = await client.get(
"/api/v1/users?is_superuser=true", "/api/v1/users?is_superuser=true",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -73,8 +68,7 @@ class TestGetCurrentUser:
async def test_get_current_user_success(self, client, async_test_user, user_token): async def test_get_current_user_success(self, client, async_test_user, user_token):
"""Test getting current user profile.""" """Test getting current user profile."""
response = await client.get( response = await client.get(
"/api/v1/users/me", "/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -92,7 +86,7 @@ class TestUpdateCurrentUser:
response = await client.patch( response = await client.patch(
"/api/v1/users/me", "/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}, headers={"Authorization": f"Bearer {user_token}"},
json={"first_name": "UpdatedName"} json={"first_name": "UpdatedName"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -104,12 +98,14 @@ class TestUpdateCurrentUser:
"""Test database error handling during update (covers lines 162-169).""" """Test database error handling during update (covers lines 162-169)."""
from unittest.mock import patch from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("DB error")): with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("DB error")
):
with pytest.raises(Exception): with pytest.raises(Exception):
await client.patch( await client.patch(
"/api/v1/users/me", "/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}, headers={"Authorization": f"Bearer {user_token}"},
json={"first_name": "Updated"} json={"first_name": "Updated"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -118,7 +114,7 @@ class TestUpdateCurrentUser:
response = await client.patch( response = await client.patch(
"/api/v1/users/me", "/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}, headers={"Authorization": f"Bearer {user_token}"},
json={"is_superuser": True} json={"is_superuser": True},
) )
# Pydantic validation should reject this at the schema level # Pydantic validation should reject this at the schema level
@@ -137,12 +133,15 @@ class TestUpdateCurrentUser:
"""Test ValueError handling during update (covers lines 165-166).""" """Test ValueError handling during update (covers lines 165-166)."""
from unittest.mock import patch from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid value")): with patch(
"app.api.routes.users.user_crud.update",
side_effect=ValueError("Invalid value"),
):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await client.patch( await client.patch(
"/api/v1/users/me", "/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}, headers={"Authorization": f"Bearer {user_token}"},
json={"first_name": "Updated"} json={"first_name": "Updated"},
) )
@@ -154,7 +153,7 @@ class TestGetUser:
"""Test getting user by ID.""" """Test getting user by ID."""
response = await client.get( response = await client.get(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -167,7 +166,7 @@ class TestGetUser:
fake_id = uuid4() fake_id = uuid4()
response = await client.get( response = await client.get(
f"/api/v1/users/{fake_id}", f"/api/v1/users/{fake_id}",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -183,30 +182,34 @@ class TestUpdateUserById:
response = await client.patch( response = await client.patch(
f"/api/v1/users/{fake_id}", f"/api/v1/users/{fake_id}",
headers={"Authorization": f"Bearer {superuser_token}"}, headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "Updated"} json={"first_name": "Updated"},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(self, client, async_test_user, user_token): async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(
self, client, async_test_user, user_token
):
"""Test non-superuser cannot modify superuser status (Pydantic validation).""" """Test non-superuser cannot modify superuser status (Pydantic validation)."""
response = await client.patch( response = await client.patch(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {user_token}"}, headers={"Authorization": f"Bearer {user_token}"},
json={"is_superuser": True} json={"is_superuser": True},
) )
# Pydantic validation should reject this at the schema level # Pydantic validation should reject this at the schema level
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_user_by_id_success(self, client, async_test_user, superuser_token): async def test_update_user_by_id_success(
self, client, async_test_user, superuser_token
):
"""Test updating user successfully (covers lines 276-278).""" """Test updating user successfully (covers lines 276-278)."""
response = await client.patch( response = await client.patch(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"}, headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "SuperUpdated"} json={"first_name": "SuperUpdated"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -214,29 +217,37 @@ class TestUpdateUserById:
assert data["first_name"] == "SuperUpdated" assert data["first_name"] == "SuperUpdated"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_user_by_id_value_error(self, client, async_test_user, superuser_token): async def test_update_user_by_id_value_error(
self, client, async_test_user, superuser_token
):
"""Test ValueError handling (covers lines 280-281).""" """Test ValueError handling (covers lines 280-281)."""
from unittest.mock import patch from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid")): with patch(
"app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid")
):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await client.patch( await client.patch(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"}, headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "Updated"} json={"first_name": "Updated"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_user_by_id_unexpected_error(self, client, async_test_user, superuser_token): async def test_update_user_by_id_unexpected_error(
self, client, async_test_user, superuser_token
):
"""Test unexpected error handling (covers lines 283-284).""" """Test unexpected error handling (covers lines 283-284)."""
from unittest.mock import patch from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("Unexpected")): with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected")
):
with pytest.raises(Exception): with pytest.raises(Exception):
await client.patch( await client.patch(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"}, headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "Updated"} json={"first_name": "Updated"},
) )
@@ -246,18 +257,18 @@ class TestChangePassword:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_password_success(self, client, async_test_db): async def test_change_password_success(self, client, async_test_db):
"""Test changing password successfully.""" """Test changing password successfully."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create a fresh user # Create a fresh user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
from app.models.user import User
from app.core.auth import get_password_hash from app.core.auth import get_password_hash
from app.models.user import User
new_user = User( new_user = User(
email="changepass@example.com", email="changepass@example.com",
password_hash=get_password_hash("OldPassword123!"), password_hash=get_password_hash("OldPassword123!"),
first_name="Change", first_name="Change",
last_name="Pass" last_name="Pass",
) )
session.add(new_user) session.add(new_user)
await session.commit() await session.commit()
@@ -265,10 +276,7 @@ class TestChangePassword:
# Login # Login
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "changepass@example.com", "password": "OldPassword123!"},
"email": "changepass@example.com",
"password": "OldPassword123!"
}
) )
token = login_response.json()["access_token"] token = login_response.json()["access_token"]
@@ -278,8 +286,8 @@ class TestChangePassword:
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}"},
json={ json={
"current_password": "OldPassword123!", "current_password": "OldPassword123!",
"new_password": "NewPassword456!" "new_password": "NewPassword456!",
} },
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -289,10 +297,7 @@ class TestChangePassword:
# Verify new password works # Verify new password works
login_response = await client.post( login_response = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={ json={"email": "changepass@example.com", "password": "NewPassword456!"},
"email": "changepass@example.com",
"password": "NewPassword456!"
}
) )
assert login_response.status_code == status.HTTP_200_OK assert login_response.status_code == status.HTTP_200_OK
@@ -306,7 +311,7 @@ class TestDeleteUserById:
fake_id = uuid4() fake_id = uuid4()
response = await client.delete( response = await client.delete(
f"/api/v1/users/{fake_id}", f"/api/v1/users/{fake_id}",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -314,18 +319,18 @@ class TestDeleteUserById:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_user_success(self, client, async_test_db, superuser_token): async def test_delete_user_success(self, client, async_test_db, superuser_token):
"""Test deleting user successfully (covers lines 383-388).""" """Test deleting user successfully (covers lines 383-388)."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create a user to delete # Create a user to delete
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
from app.models.user import User
from app.core.auth import get_password_hash from app.core.auth import get_password_hash
from app.models.user import User
user_to_delete = User( user_to_delete = User(
email=f"delete{uuid4().hex[:8]}@example.com", email=f"delete{uuid4().hex[:8]}@example.com",
password_hash=get_password_hash("Password123!"), password_hash=get_password_hash("Password123!"),
first_name="Delete", first_name="Delete",
last_name="Me" last_name="Me",
) )
session.add(user_to_delete) session.add(user_to_delete)
await session.commit() await session.commit()
@@ -334,7 +339,7 @@ class TestDeleteUserById:
response = await client.delete( response = await client.delete(
f"/api/v1/users/{user_id}", f"/api/v1/users/{user_id}",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -342,25 +347,35 @@ class TestDeleteUserById:
assert data["success"] is True assert data["success"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_user_value_error(self, client, async_test_user, superuser_token): async def test_delete_user_value_error(
self, client, async_test_user, superuser_token
):
"""Test ValueError handling during delete (covers lines 390-391).""" """Test ValueError handling during delete (covers lines 390-391)."""
from unittest.mock import patch from unittest.mock import patch
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=ValueError("Cannot delete")): with patch(
"app.api.routes.users.user_crud.soft_delete",
side_effect=ValueError("Cannot delete"),
):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await client.delete( await client.delete(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_user_unexpected_error(self, client, async_test_user, superuser_token): async def test_delete_user_unexpected_error(
self, client, async_test_user, superuser_token
):
"""Test unexpected error handling during delete (covers lines 393-394).""" """Test unexpected error handling during delete (covers lines 393-394)."""
from unittest.mock import patch from unittest.mock import patch
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=Exception("Unexpected")): with patch(
"app.api.routes.users.user_crud.soft_delete",
side_effect=Exception("Unexpected"),
):
with pytest.raises(Exception): with pytest.raises(Exception):
await client.delete( await client.delete(
f"/api/v1/users/{async_test_user.id}", f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"} headers={"Authorization": f"Bearer {superuser_token}"},
) )

View File

@@ -1,21 +1,25 @@
# tests/conftest.py # tests/conftest.py
import os import os
import uuid import uuid
from datetime import datetime, timezone
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from httpx import AsyncClient, ASGITransport from httpx import ASGITransport, AsyncClient
# Set IS_TEST environment variable BEFORE importing app # Set IS_TEST environment variable BEFORE importing app
# This prevents the scheduler from starting during tests # This prevents the scheduler from starting during tests
os.environ["IS_TEST"] = "True" os.environ["IS_TEST"] = "True"
from app.main import app
from app.core.database import get_db
from app.models.user import User
from app.core.auth import get_password_hash from app.core.auth import get_password_hash
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db from app.core.database import get_db
from app.main import app
from app.models.user import User
from app.utils.test_utils import (
setup_async_test_db,
setup_test_db,
teardown_async_test_db,
teardown_test_db,
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
@@ -46,6 +50,7 @@ async def async_test_db():
yield test_engine, AsyncTestingSessionLocal yield test_engine, AsyncTestingSessionLocal
await teardown_async_test_db(test_engine) await teardown_async_test_db(test_engine)
@pytest.fixture @pytest.fixture
def user_create_data(): def user_create_data():
return { return {
@@ -55,7 +60,7 @@ def user_create_data():
"last_name": "User", "last_name": "User",
"phone_number": "+1234567890", "phone_number": "+1234567890",
"is_superuser": False, "is_superuser": False,
"preferences": None "preferences": None,
} }
@@ -102,7 +107,7 @@ async def client(async_test_db):
This overrides the get_db dependency to use the test database. This overrides the get_db dependency to use the test database.
""" """
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async def override_get_db(): async def override_get_db():
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -176,7 +181,7 @@ async def async_test_user(async_test_db):
Password: TestPassword123 Password: TestPassword123
""" """
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = User( user = User(
id=uuid.uuid4(), id=uuid.uuid4(),
@@ -202,7 +207,7 @@ async def async_test_superuser(async_test_db):
Password: SuperPassword123 Password: SuperPassword123
""" """
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = User( user = User(
id=uuid.uuid4(), id=uuid.uuid4(),

View File

@@ -1,20 +1,20 @@
# tests/core/test_auth.py # tests/core/test_auth.py
import uuid import uuid
from datetime import UTC, datetime, timedelta
import pytest import pytest
from datetime import datetime, timedelta, timezone
from jose import jwt from jose import jwt
from pydantic import ValidationError
from app.core.auth import ( from app.core.auth import (
verify_password, TokenExpiredError,
get_password_hash, TokenInvalidError,
TokenMissingClaimError,
create_access_token, create_access_token,
create_refresh_token, create_refresh_token,
decode_token, decode_token,
get_password_hash,
get_token_data, get_token_data,
TokenExpiredError, verify_password,
TokenInvalidError,
TokenMissingClaimError
) )
from app.core.config import settings from app.core.config import settings
@@ -58,15 +58,13 @@ class TestTokenCreation:
custom_claims = { custom_claims = {
"email": "test@example.com", "email": "test@example.com",
"first_name": "Test", "first_name": "Test",
"is_superuser": True "is_superuser": True,
} }
token = create_access_token(subject=user_id, claims=custom_claims) token = create_access_token(subject=user_id, claims=custom_claims)
# Decode token to verify claims # Decode token to verify claims
payload = jwt.decode( payload = jwt.decode(
token, token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
) )
# Check standard claims # Check standard claims
@@ -87,9 +85,7 @@ class TestTokenCreation:
# Decode token to verify claims # Decode token to verify claims
payload = jwt.decode( payload = jwt.decode(
token, token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
) )
# Check standard claims # Check standard claims
@@ -105,23 +101,18 @@ class TestTokenCreation:
expires = timedelta(minutes=5) expires = timedelta(minutes=5)
# Create token with specific expiration # Create token with specific expiration
token = create_access_token( token = create_access_token(subject=user_id, expires_delta=expires)
subject=user_id,
expires_delta=expires
)
# Decode token # Decode token
payload = jwt.decode( payload = jwt.decode(
token, token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
) )
# Get actual expiration time from token # Get actual expiration time from token
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) expiration = datetime.fromtimestamp(payload["exp"], tz=UTC)
# Calculate expected expiration (approximately) # Calculate expected expiration (approximately)
now = datetime.now(timezone.utc) now = datetime.now(UTC)
expected_expiration = now + expires expected_expiration = now + expires
# Difference should be small (less than 1 second) # Difference should be small (less than 1 second)
@@ -148,7 +139,7 @@ class TestTokenDecoding:
user_id = str(uuid.uuid4()) user_id = str(uuid.uuid4())
# Create a token that's already expired by directly manipulating the payload # Create a token that's already expired by directly manipulating the payload
now = datetime.now(timezone.utc) now = datetime.now(UTC)
expired_time = now - timedelta(hours=1) # 1 hour in the past expired_time = now - timedelta(hours=1) # 1 hour in the past
# Create the expired token manually # Create the expired token manually
@@ -157,13 +148,11 @@ class TestTokenDecoding:
"exp": int(expired_time.timestamp()), # Set expiration in the past "exp": int(expired_time.timestamp()), # Set expiration in the past
"iat": int(now.timestamp()), "iat": int(now.timestamp()),
"jti": str(uuid.uuid4()), "jti": str(uuid.uuid4()),
"type": "access" "type": "access",
} }
expired_token = jwt.encode( expired_token = jwt.encode(
payload, payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
) )
# Attempting to decode should raise TokenExpiredError # Attempting to decode should raise TokenExpiredError
@@ -180,20 +169,16 @@ class TestTokenDecoding:
def test_decode_token_with_missing_sub(self): def test_decode_token_with_missing_sub(self):
"""Test that a token without 'sub' claim raises TokenMissingClaimError""" """Test that a token without 'sub' claim raises TokenMissingClaimError"""
# Create a token without a subject # Create a token without a subject
now = datetime.now(timezone.utc) now = datetime.now(UTC)
payload = { payload = {
"exp": int((now + timedelta(minutes=30)).timestamp()), "exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()), "iat": int(now.timestamp()),
"jti": str(uuid.uuid4()), "jti": str(uuid.uuid4()),
"type": "access" "type": "access",
# No 'sub' claim # No 'sub' claim
} }
token = jwt.encode( token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
with pytest.raises(TokenMissingClaimError): with pytest.raises(TokenMissingClaimError):
decode_token(token) decode_token(token)
@@ -211,20 +196,16 @@ class TestTokenDecoding:
"""Test that a token with invalid payload structure raises TokenInvalidError""" """Test that a token with invalid payload structure raises TokenInvalidError"""
# Create a token with an invalid payload structure - missing 'sub' which is required # Create a token with an invalid payload structure - missing 'sub' which is required
# but including 'exp' to avoid the expiration check # but including 'exp' to avoid the expiration check
now = datetime.now(timezone.utc) now = datetime.now(UTC)
payload = { payload = {
# Missing "sub" field which is required # Missing "sub" field which is required
"exp": int((now + timedelta(minutes=30)).timestamp()), "exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()), "iat": int(now.timestamp()),
"jti": str(uuid.uuid4()), "jti": str(uuid.uuid4()),
"invalid_field": "test" "invalid_field": "test",
} }
token = jwt.encode( token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Should raise TokenMissingClaimError due to missing 'sub' # Should raise TokenMissingClaimError due to missing 'sub'
with pytest.raises(TokenMissingClaimError): with pytest.raises(TokenMissingClaimError):
@@ -236,11 +217,7 @@ class TestTokenDecoding:
"exp": int((now + timedelta(minutes=30)).timestamp()), "exp": int((now + timedelta(minutes=30)).timestamp()),
} }
token = jwt.encode( token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
# Should raise TokenInvalidError due to ValidationError # Should raise TokenInvalidError due to ValidationError
with pytest.raises(TokenInvalidError): with pytest.raises(TokenInvalidError):
@@ -249,10 +226,7 @@ class TestTokenDecoding:
def test_get_token_data(self): def test_get_token_data(self):
"""Test extracting TokenData from a token""" """Test extracting TokenData from a token"""
user_id = uuid.uuid4() user_id = uuid.uuid4()
token = create_access_token( token = create_access_token(subject=str(user_id), claims={"is_superuser": True})
subject=str(user_id),
claims={"is_superuser": True}
)
token_data = get_token_data(token) token_data = get_token_data(token)

View File

@@ -8,11 +8,11 @@ Critical security tests covering:
These tests cover critical security vulnerabilities that could be exploited. These tests cover critical security vulnerabilities that could be exploited.
""" """
import pytest import pytest
from jose import jwt from jose import jwt
from datetime import datetime, timedelta, timezone
from app.core.auth import decode_token, create_access_token, TokenInvalidError from app.core.auth import TokenInvalidError, create_access_token, decode_token
from app.core.config import settings from app.core.config import settings
@@ -46,13 +46,14 @@ class TestJWTAlgorithmSecurityAttacks:
""" """
# Create a payload that would normally be valid (using timestamps) # Create a payload that would normally be valid (using timestamps)
import time import time
now = int(time.time()) now = int(time.time())
payload = { payload = {
"sub": "user123", "sub": "user123",
"exp": now + 3600, # 1 hour from now "exp": now + 3600, # 1 hour from now
"iat": now, "iat": now,
"type": "access" "type": "access",
} }
# Craft a malicious token with "alg: none" # Craft a malicious token with "alg: none"
@@ -61,13 +62,13 @@ class TestJWTAlgorithmSecurityAttacks:
import json import json
header = {"alg": "none", "typ": "JWT"} header = {"alg": "none", "typ": "JWT"}
header_encoded = base64.urlsafe_b64encode( header_encoded = (
json.dumps(header).encode() base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
).decode().rstrip("=") )
payload_encoded = base64.urlsafe_b64encode( payload_encoded = (
json.dumps(payload).encode() base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
).decode().rstrip("=") )
# Token with no signature (algorithm "none") # Token with no signature (algorithm "none")
malicious_token = f"{header_encoded}.{payload_encoded}." malicious_token = f"{header_encoded}.{payload_encoded}."
@@ -85,22 +86,17 @@ class TestJWTAlgorithmSecurityAttacks:
import time import time
now = int(time.time()) now = int(time.time())
payload = { payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
# Try uppercase "NONE" # Try uppercase "NONE"
header = {"alg": "NONE", "typ": "JWT"} header = {"alg": "NONE", "typ": "JWT"}
header_encoded = base64.urlsafe_b64encode( header_encoded = (
json.dumps(header).encode() base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
).decode().rstrip("=") )
payload_encoded = base64.urlsafe_b64encode( payload_encoded = (
json.dumps(payload).encode() base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
).decode().rstrip("=") )
malicious_token = f"{header_encoded}.{payload_encoded}." malicious_token = f"{header_encoded}.{payload_encoded}."
@@ -121,15 +117,11 @@ class TestJWTAlgorithmSecurityAttacks:
before our defensive checks at line 212. This is good for security! before our defensive checks at line 212. This is good for security!
""" """
import time import time
now = int(time.time()) now = int(time.time())
# Create a valid payload # Create a valid payload
payload = { payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
# Encode with wrong algorithm (RS256 instead of HS256) # Encode with wrong algorithm (RS256 instead of HS256)
# This simulates an attacker trying algorithm substitution # This simulates an attacker trying algorithm substitution
@@ -137,9 +129,7 @@ class TestJWTAlgorithmSecurityAttacks:
try: try:
malicious_token = jwt.encode( malicious_token = jwt.encode(
payload, payload, settings.SECRET_KEY, algorithm=wrong_algorithm
settings.SECRET_KEY,
algorithm=wrong_algorithm
) )
# Should reject the token (library catches mismatch) # Should reject the token (library catches mismatch)
@@ -156,21 +146,15 @@ class TestJWTAlgorithmSecurityAttacks:
Prevents algorithm downgrade/upgrade attacks. Prevents algorithm downgrade/upgrade attacks.
""" """
import time import time
now = int(time.time()) now = int(time.time())
payload = { payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
# Create token with HS384 instead of HS256 # Create token with HS384 instead of HS256
try: try:
malicious_token = jwt.encode( malicious_token = jwt.encode(
payload, payload, settings.SECRET_KEY, algorithm="HS384"
settings.SECRET_KEY,
algorithm="HS384"
) )
with pytest.raises(TokenInvalidError): with pytest.raises(TokenInvalidError):
@@ -223,20 +207,15 @@ class TestJWTSecurityEdgeCases:
# Create token without "alg" in header # Create token without "alg" in header
header = {"typ": "JWT"} # Missing "alg" header = {"typ": "JWT"} # Missing "alg"
payload = { payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
header_encoded = base64.urlsafe_b64encode( header_encoded = (
json.dumps(header).encode() base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
).decode().rstrip("=") )
payload_encoded = base64.urlsafe_b64encode( payload_encoded = (
json.dumps(payload).encode() base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
).decode().rstrip("=") )
malicious_token = f"{header_encoded}.{payload_encoded}.fake_signature" malicious_token = f"{header_encoded}.{payload_encoded}.fake_signature"
@@ -253,15 +232,20 @@ class TestJWTSecurityEdgeCases:
"""Test token with malformed JSON in payload.""" """Test token with malformed JSON in payload."""
import base64 import base64
header = {"alg": "HS256", "typ": "JWT"} header_encoded = (
header_encoded = base64.urlsafe_b64encode( base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}')
b'{"alg":"HS256","typ":"JWT"}' .decode()
).decode().rstrip("=") .rstrip("=")
)
# Invalid JSON (missing closing brace) # Invalid JSON (missing closing brace)
invalid_payload_encoded = base64.urlsafe_b64encode( invalid_payload_encoded = (
b'{"sub":"user123"' # Invalid JSON base64.urlsafe_b64encode(
).decode().rstrip("=") b'{"sub":"user123"' # Invalid JSON
)
.decode()
.rstrip("=")
)
malicious_token = f"{header_encoded}.{invalid_payload_encoded}.fake_sig" malicious_token = f"{header_encoded}.{invalid_payload_encoded}.fake_sig"

View File

@@ -1,6 +1,7 @@
# tests/core/test_config.py # tests/core/test_config.py
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from app.core.config import Settings from app.core.config import Settings
@@ -22,11 +23,15 @@ class TestSecretKeyValidation:
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
Settings(SECRET_KEY=default_key, ENVIRONMENT="production") Settings(SECRET_KEY=default_key, ENVIRONMENT="production")
assert "must be set to a secure random value in production" in str(exc_info.value) assert "must be set to a secure random value in production" in str(
exc_info.value
)
def test_default_secret_key_in_development_allows_with_warning(self, caplog): def test_default_secret_key_in_development_allows_with_warning(self, caplog):
"""Test that default SECRET_KEY in development is allowed but warns""" """Test that default SECRET_KEY in development is allowed but warns"""
settings = Settings(SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development") settings = Settings(
SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development"
)
assert settings.SECRET_KEY == "your_secret_key_here" + "x" * 14 assert settings.SECRET_KEY == "your_secret_key_here" + "x" * 14
# Note: The warning happens during validation, which we've seen works # Note: The warning happens during validation, which we've seen works
@@ -44,19 +49,13 @@ class TestSuperuserPasswordValidation:
def test_none_password_accepted(self): def test_none_password_accepted(self):
"""Test that None password is accepted (optional field)""" """Test that None password is accepted (optional field)"""
settings = Settings( settings = Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=None)
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD=None
)
assert settings.FIRST_SUPERUSER_PASSWORD is None assert settings.FIRST_SUPERUSER_PASSWORD is None
def test_password_too_short_raises_error(self): def test_password_too_short_raises_error(self):
"""Test that password shorter than 12 characters raises error""" """Test that password shorter than 12 characters raises error"""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
Settings( Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="Short1")
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="Short1"
)
assert "must be at least 12 characters" in str(exc_info.value) assert "must be at least 12 characters" in str(exc_info.value)
@@ -64,14 +63,11 @@ class TestSuperuserPasswordValidation:
"""Test that common weak passwords are rejected""" """Test that common weak passwords are rejected"""
# Test with the exact weak passwords from the validator # Test with the exact weak passwords from the validator
# These are in the weak_passwords set and should be rejected # These are in the weak_passwords set and should be rejected
weak_passwords = ['123456789012'] # Exactly 12 chars, in the weak set weak_passwords = ["123456789012"] # Exactly 12 chars, in the weak set
for weak_pwd in weak_passwords: for weak_pwd in weak_passwords:
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
Settings( Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=weak_pwd)
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD=weak_pwd
)
# Should get "too weak" message # Should get "too weak" message
error_str = str(exc_info.value) error_str = str(exc_info.value)
assert "too weak" in error_str assert "too weak" in error_str
@@ -79,30 +75,21 @@ class TestSuperuserPasswordValidation:
def test_password_without_lowercase_rejected(self): def test_password_without_lowercase_rejected(self):
"""Test that password without lowercase is rejected""" """Test that password without lowercase is rejected"""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
Settings( Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123")
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123"
)
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value) assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
def test_password_without_uppercase_rejected(self): def test_password_without_uppercase_rejected(self):
"""Test that password without uppercase is rejected""" """Test that password without uppercase is rejected"""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
Settings( Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="alllowercase123")
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="alllowercase123"
)
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value) assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
def test_password_without_digit_rejected(self): def test_password_without_digit_rejected(self):
"""Test that password without digit is rejected""" """Test that password without digit is rejected"""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
Settings( Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="NoDigitsHere")
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="NoDigitsHere"
)
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value) assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
@@ -110,8 +97,7 @@ class TestSuperuserPasswordValidation:
"""Test that strong password is accepted""" """Test that strong password is accepted"""
strong_password = "StrongPassword123!" strong_password = "StrongPassword123!"
settings = Settings( settings = Settings(
SECRET_KEY="a" * 32, SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=strong_password
FIRST_SUPERUSER_PASSWORD=strong_password
) )
assert settings.FIRST_SUPERUSER_PASSWORD == strong_password assert settings.FIRST_SUPERUSER_PASSWORD == strong_password
@@ -150,7 +136,7 @@ class TestDatabaseConfiguration:
POSTGRES_HOST="testhost", POSTGRES_HOST="testhost",
POSTGRES_PORT="5432", POSTGRES_PORT="5432",
POSTGRES_DB="testdb", POSTGRES_DB="testdb",
DATABASE_URL=None # Don't use explicit URL DATABASE_URL=None, # Don't use explicit URL
) )
expected_url = "postgresql://testuser:testpass@testhost:5432/testdb" expected_url = "postgresql://testuser:testpass@testhost:5432/testdb"
@@ -159,10 +145,7 @@ class TestDatabaseConfiguration:
def test_explicit_database_url_used_when_set(self): def test_explicit_database_url_used_when_set(self):
"""Test that explicit DATABASE_URL is used when provided""" """Test that explicit DATABASE_URL is used when provided"""
explicit_url = "postgresql://explicit:pass@host:5432/db" explicit_url = "postgresql://explicit:pass@host:5432/db"
settings = Settings( settings = Settings(SECRET_KEY="a" * 32, DATABASE_URL=explicit_url)
SECRET_KEY="a" * 32,
DATABASE_URL=explicit_url
)
assert settings.database_url == explicit_url assert settings.database_url == explicit_url

View File

@@ -6,8 +6,10 @@ Critical security tests covering:
These tests prevent security misconfigurations. These tests prevent security misconfigurations.
""" """
import pytest
import os import os
import pytest
from pydantic import ValidationError from pydantic import ValidationError
@@ -43,6 +45,7 @@ class TestSecretKeySecurityValidation:
# Import Settings class fresh (to pick up new env var) # Import Settings class fresh (to pick up new env var)
# The ValidationError should be raised during reload when Settings() is instantiated # The ValidationError should be raised during reload when Settings() is instantiated
import importlib import importlib
from app.core import config from app.core import config
# Reload will raise ValidationError because Settings() is instantiated at module level # Reload will raise ValidationError because Settings() is instantiated at module level
@@ -58,7 +61,9 @@ class TestSecretKeySecurityValidation:
# Reload config to restore original settings # Reload config to restore original settings
import importlib import importlib
from app.core import config from app.core import config
importlib.reload(config) importlib.reload(config)
def test_secret_key_exactly_32_characters_accepted(self): def test_secret_key_exactly_32_characters_accepted(self):
@@ -75,7 +80,9 @@ class TestSecretKeySecurityValidation:
os.environ["SECRET_KEY"] = key_32 os.environ["SECRET_KEY"] = key_32
import importlib import importlib
from app.core import config from app.core import config
importlib.reload(config) importlib.reload(config)
# Should work # Should work
@@ -89,7 +96,9 @@ class TestSecretKeySecurityValidation:
os.environ.pop("SECRET_KEY", None) os.environ.pop("SECRET_KEY", None)
import importlib import importlib
from app.core import config from app.core import config
importlib.reload(config) importlib.reload(config)
def test_secret_key_long_enough_accepted(self): def test_secret_key_long_enough_accepted(self):
@@ -106,7 +115,9 @@ class TestSecretKeySecurityValidation:
os.environ["SECRET_KEY"] = key_64 os.environ["SECRET_KEY"] = key_64
import importlib import importlib
from app.core import config from app.core import config
importlib.reload(config) importlib.reload(config)
# Should work # Should work
@@ -120,7 +131,9 @@ class TestSecretKeySecurityValidation:
os.environ.pop("SECRET_KEY", None) os.environ.pop("SECRET_KEY", None)
import importlib import importlib
from app.core import config from app.core import config
importlib.reload(config) importlib.reload(config)
def test_default_secret_key_meets_requirements(self): def test_default_secret_key_meets_requirements(self):
@@ -132,4 +145,6 @@ class TestSecretKeySecurityValidation:
from app.core.config import settings from app.core.config import settings
# Current settings should have valid SECRET_KEY # Current settings should have valid SECRET_KEY
assert len(settings.SECRET_KEY) >= 32, "Default SECRET_KEY must be at least 32 chars" assert len(settings.SECRET_KEY) >= 32, (
"Default SECRET_KEY must be at least 32 chars"
)

View File

@@ -9,18 +9,19 @@ Covers:
- init_async_db - init_async_db
- close_async_db - close_async_db
""" """
from unittest.mock import patch
import pytest import pytest
import pytest_asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import ( from app.core.database import (
get_async_database_url,
get_db,
async_transaction_scope, async_transaction_scope,
check_async_database_health, check_async_database_health,
init_async_db,
close_async_db, close_async_db,
get_async_database_url,
get_db,
init_async_db,
) )
@@ -88,12 +89,13 @@ class TestAsyncTransactionScope:
async def test_transaction_scope_commits_on_success(self, async_test_db): async def test_transaction_scope_commits_on_success(self, async_test_db):
"""Test that successful operations are committed (covers line 138).""" """Test that successful operations are committed (covers line 138)."""
# Mock the transaction scope to use test database # Mock the transaction scope to use test database
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
with patch('app.core.database.SessionLocal', SessionLocal): with patch("app.core.database.SessionLocal", SessionLocal):
async with async_transaction_scope() as db: async with async_transaction_scope() as db:
# Execute a simple query to verify transaction works # Execute a simple query to verify transaction works
from sqlalchemy import text from sqlalchemy import text
result = await db.execute(text("SELECT 1")) result = await db.execute(text("SELECT 1"))
assert result is not None assert result is not None
# Transaction should be committed (covers line 138 debug log) # Transaction should be committed (covers line 138 debug log)
@@ -101,12 +103,13 @@ class TestAsyncTransactionScope:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_transaction_scope_rollback_on_error(self, async_test_db): async def test_transaction_scope_rollback_on_error(self, async_test_db):
"""Test that transaction rolls back on exception.""" """Test that transaction rolls back on exception."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
with patch('app.core.database.SessionLocal', SessionLocal): with patch("app.core.database.SessionLocal", SessionLocal):
with pytest.raises(RuntimeError, match="Test error"): with pytest.raises(RuntimeError, match="Test error"):
async with async_transaction_scope() as db: async with async_transaction_scope() as db:
from sqlalchemy import text from sqlalchemy import text
await db.execute(text("SELECT 1")) await db.execute(text("SELECT 1"))
raise RuntimeError("Test error") raise RuntimeError("Test error")
@@ -117,9 +120,9 @@ class TestCheckAsyncDatabaseHealth:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_database_health_check_success(self, async_test_db): async def test_database_health_check_success(self, async_test_db):
"""Test health check returns True on success (covers line 156).""" """Test health check returns True on success (covers line 156)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
with patch('app.core.database.SessionLocal', SessionLocal): with patch("app.core.database.SessionLocal", SessionLocal):
result = await check_async_database_health() result = await check_async_database_health()
assert result is True assert result is True
@@ -127,7 +130,7 @@ class TestCheckAsyncDatabaseHealth:
async def test_database_health_check_failure(self): async def test_database_health_check_failure(self):
"""Test health check returns False on database error.""" """Test health check returns False on database error."""
# Mock async_transaction_scope to raise an error # Mock async_transaction_scope to raise an error
with patch('app.core.database.async_transaction_scope') as mock_scope: with patch("app.core.database.async_transaction_scope") as mock_scope:
mock_scope.side_effect = Exception("Database connection failed") mock_scope.side_effect = Exception("Database connection failed")
result = await check_async_database_health() result = await check_async_database_health()
@@ -140,10 +143,10 @@ class TestInitAsyncDb:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_init_async_db_creates_tables(self, async_test_db): async def test_init_async_db_creates_tables(self, async_test_db):
"""Test init_async_db creates tables (covers lines 174-176).""" """Test init_async_db creates tables (covers lines 174-176)."""
test_engine, SessionLocal = async_test_db test_engine, _SessionLocal = async_test_db
# Mock the engine to use test engine # Mock the engine to use test engine
with patch('app.core.database.engine', test_engine): with patch("app.core.database.engine", test_engine):
await init_async_db() await init_async_db()
# If no exception, tables were created successfully # If no exception, tables were created successfully
@@ -155,7 +158,6 @@ class TestCloseAsyncDb:
async def test_close_async_db_disposes_engine(self): async def test_close_async_db_disposes_engine(self):
"""Test close_async_db disposes engine (covers lines 185-186).""" """Test close_async_db disposes engine (covers lines 185-186)."""
# Create a fresh engine to test closing # Create a fresh engine to test closing
from app.core.database import engine
# Close connections # Close connections
await close_async_db() await close_async_db()

View File

@@ -2,14 +2,16 @@
""" """
Comprehensive tests for CRUDBase class covering all error paths and edge cases. Comprehensive tests for CRUDBase class covering all error paths and edge cases.
""" """
from datetime import UTC
from unittest.mock import patch
from uuid import uuid4
import pytest import pytest
from uuid import uuid4, UUID from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from unittest.mock import AsyncMock, patch, MagicMock
from app.crud.user import user as user_crud from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
@@ -19,7 +21,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_with_invalid_uuid_string(self, async_test_db): async def test_get_with_invalid_uuid_string(self, async_test_db):
"""Test get with invalid UUID string returns None.""" """Test get with invalid UUID string returns None."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.get(session, id="invalid-uuid") result = await user_crud.get(session, id="invalid-uuid")
@@ -28,7 +30,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_with_invalid_uuid_type(self, async_test_db): async def test_get_with_invalid_uuid_type(self, async_test_db):
"""Test get with invalid UUID type returns None.""" """Test get with invalid UUID type returns None."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.get(session, id=12345) # int instead of UUID result = await user_crud.get(session, id=12345) # int instead of UUID
@@ -37,7 +39,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_with_uuid_object(self, async_test_db, async_test_user): async def test_get_with_uuid_object(self, async_test_db, async_test_user):
"""Test get with UUID object instead of string.""" """Test get with UUID object instead of string."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
# Pass UUID object directly # Pass UUID object directly
@@ -48,26 +50,24 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_with_options(self, async_test_db, async_test_user): async def test_get_with_options(self, async_test_db, async_test_user):
"""Test get with eager loading options (tests lines 76-78).""" """Test get with eager loading options (tests lines 76-78)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
# Test that options parameter is accepted and doesn't error # Test that options parameter is accepted and doesn't error
# We pass an empty list which still tests the code path # We pass an empty list which still tests the code path
result = await user_crud.get( result = await user_crud.get(
session, session, id=str(async_test_user.id), options=[]
id=str(async_test_user.id),
options=[]
) )
assert result is not None assert result is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_database_error(self, async_test_db): async def test_get_database_error(self, async_test_db):
"""Test get handles database errors properly.""" """Test get handles database errors properly."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
# Mock execute to raise an exception # Mock execute to raise an exception
with patch.object(session, 'execute', side_effect=Exception("DB error")): with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"): with pytest.raises(Exception, match="DB error"):
await user_crud.get(session, id=str(uuid4())) await user_crud.get(session, id=str(uuid4()))
@@ -78,7 +78,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_negative_skip(self, async_test_db): async def test_get_multi_negative_skip(self, async_test_db):
"""Test get_multi with negative skip raises ValueError.""" """Test get_multi with negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"): with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -87,7 +87,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_negative_limit(self, async_test_db): async def test_get_multi_negative_limit(self, async_test_db):
"""Test get_multi with negative limit raises ValueError.""" """Test get_multi with negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"): with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -96,7 +96,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_limit_too_large(self, async_test_db): async def test_get_multi_limit_too_large(self, async_test_db):
"""Test get_multi with limit > 1000 raises ValueError.""" """Test get_multi with limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"): with pytest.raises(ValueError, match="Maximum limit is 1000"):
@@ -105,25 +105,20 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_options(self, async_test_db, async_test_user): async def test_get_multi_with_options(self, async_test_db, async_test_user):
"""Test get_multi with eager loading options (tests lines 118-120).""" """Test get_multi with eager loading options (tests lines 118-120)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
# Test that options parameter is accepted # Test that options parameter is accepted
results = await user_crud.get_multi( results = await user_crud.get_multi(session, skip=0, limit=10, options=[])
session,
skip=0,
limit=10,
options=[]
)
assert isinstance(results, list) assert isinstance(results, list)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_database_error(self, async_test_db): async def test_get_multi_database_error(self, async_test_db):
"""Test get_multi handles database errors.""" """Test get_multi handles database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("DB error")): with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"): with pytest.raises(Exception, match="DB error"):
await user_crud.get_multi(session) await user_crud.get_multi(session)
@@ -134,7 +129,7 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_duplicate_unique_field(self, async_test_db, async_test_user): async def test_create_duplicate_unique_field(self, async_test_db, async_test_user):
"""Test create with duplicate unique field raises ValueError.""" """Test create with duplicate unique field raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
# Try to create user with duplicate email # Try to create user with duplicate email
@@ -142,7 +137,7 @@ class TestCRUDBaseCreate:
email=async_test_user.email, # Duplicate! email=async_test_user.email, # Duplicate!
password="TestPassword123!", password="TestPassword123!",
first_name="Test", first_name="Test",
last_name="Duplicate" last_name="Duplicate",
) )
with pytest.raises(ValueError, match="already exists"): with pytest.raises(ValueError, match="already exists"):
@@ -151,22 +146,23 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_integrity_error_non_duplicate(self, async_test_db): async def test_create_integrity_error_non_duplicate(self, async_test_db):
"""Test create with non-duplicate IntegrityError.""" """Test create with non-duplicate IntegrityError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
# Mock commit to raise IntegrityError without "unique" in message # Mock commit to raise IntegrityError without "unique" in message
original_commit = session.commit
async def mock_commit(): async def mock_commit():
error = IntegrityError("statement", {}, Exception("foreign key violation")) error = IntegrityError(
"statement", {}, Exception("foreign key violation")
)
raise error raise error
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
user_data = UserCreate( user_data = UserCreate(
email="test@example.com", email="test@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Test", first_name="Test",
last_name="User" last_name="User",
) )
with pytest.raises(ValueError, match="Database integrity error"): with pytest.raises(ValueError, match="Database integrity error"):
@@ -175,15 +171,21 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_operational_error(self, async_test_db): async def test_create_operational_error(self, async_test_db):
"""Test create with OperationalError (user CRUD catches as generic Exception).""" """Test create with OperationalError (user CRUD catches as generic Exception)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection lost"))): with patch.object(
session,
"commit",
side_effect=OperationalError(
"statement", {}, Exception("connection lost")
),
):
user_data = UserCreate( user_data = UserCreate(
email="test@example.com", email="test@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Test", first_name="Test",
last_name="User" last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User CRUD catches this as generic Exception and re-raises
@@ -193,15 +195,19 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_data_error(self, async_test_db): async def test_create_data_error(self, async_test_db):
"""Test create with DataError (user CRUD catches as generic Exception).""" """Test create with DataError (user CRUD catches as generic Exception)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=DataError("statement", {}, Exception("invalid data"))): with patch.object(
session,
"commit",
side_effect=DataError("statement", {}, Exception("invalid data")),
):
user_data = UserCreate( user_data = UserCreate(
email="test@example.com", email="test@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Test", first_name="Test",
last_name="User" last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User CRUD catches this as generic Exception and re-raises
@@ -211,15 +217,17 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db): async def test_create_unexpected_error(self, async_test_db):
"""Test create with unexpected exception.""" """Test create with unexpected exception."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected error")): with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected error")
):
user_data = UserCreate( user_data = UserCreate(
email="test@example.com", email="test@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Test", first_name="Test",
last_name="User" last_name="User",
) )
with pytest.raises(RuntimeError, match="Unexpected error"): with pytest.raises(RuntimeError, match="Unexpected error"):
@@ -232,16 +240,17 @@ class TestCRUDBaseUpdate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_duplicate_unique_field(self, async_test_db, async_test_user): async def test_update_duplicate_unique_field(self, async_test_db, async_test_user):
"""Test update with duplicate unique field raises ValueError.""" """Test update with duplicate unique field raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create another user # Create another user
async with SessionLocal() as session: async with SessionLocal() as session:
from app.crud.user import user as user_crud from app.crud.user import user as user_crud
user2_data = UserCreate( user2_data = UserCreate(
email="user2@example.com", email="user2@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="User", first_name="User",
last_name="Two" last_name="Two",
) )
user2 = await user_crud.create(session, obj_in=user2_data) user2 = await user_crud.create(session, obj_in=user2_data)
await session.commit() await session.commit()
@@ -250,63 +259,89 @@ class TestCRUDBaseUpdate:
async with SessionLocal() as session: async with SessionLocal() as session:
user2_obj = await user_crud.get(session, id=str(user2.id)) user2_obj = await user_crud.get(session, id=str(user2.id))
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("UNIQUE constraint failed"))): with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("UNIQUE constraint failed")
),
):
update_data = UserUpdate(email=async_test_user.email) update_data = UserUpdate(email=async_test_user.email)
with pytest.raises(ValueError, match="already exists"): with pytest.raises(ValueError, match="already exists"):
await user_crud.update(session, db_obj=user2_obj, obj_in=update_data) await user_crud.update(
session, db_obj=user2_obj, obj_in=update_data
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_with_dict(self, async_test_db, async_test_user): async def test_update_with_dict(self, async_test_db, async_test_user):
"""Test update with dict instead of schema.""" """Test update with dict instead of schema."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
# Update with dict (tests lines 164-165) # Update with dict (tests lines 164-165)
updated = await user_crud.update( updated = await user_crud.update(
session, session, db_obj=user, obj_in={"first_name": "UpdatedName"}
db_obj=user,
obj_in={"first_name": "UpdatedName"}
) )
assert updated.first_name == "UpdatedName" assert updated.first_name == "UpdatedName"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_integrity_error(self, async_test_db, async_test_user): async def test_update_integrity_error(self, async_test_db, async_test_user):
"""Test update with IntegrityError.""" """Test update with IntegrityError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("constraint failed"))): with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("constraint failed")
),
):
with pytest.raises(ValueError, match="Database integrity error"): with pytest.raises(ValueError, match="Database integrity error"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"}) await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user): async def test_update_operational_error(self, async_test_db, async_test_user):
"""Test update with OperationalError.""" """Test update with OperationalError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection error"))): with patch.object(
session,
"commit",
side_effect=OperationalError(
"statement", {}, Exception("connection error")
),
):
with pytest.raises(ValueError, match="Database operation failed"): with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"}) await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_unexpected_error(self, async_test_db, async_test_user): async def test_update_unexpected_error(self, async_test_db, async_test_user):
"""Test update with unexpected error.""" """Test update with unexpected error."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")): with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"}) await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
class TestCRUDBaseRemove: class TestCRUDBaseRemove:
@@ -315,7 +350,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_invalid_uuid(self, async_test_db): async def test_remove_invalid_uuid(self, async_test_db):
"""Test remove with invalid UUID returns None.""" """Test remove with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.remove(session, id="invalid-uuid") result = await user_crud.remove(session, id="invalid-uuid")
@@ -324,7 +359,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_with_uuid_object(self, async_test_db, async_test_user): async def test_remove_with_uuid_object(self, async_test_db, async_test_user):
"""Test remove with UUID object.""" """Test remove with UUID object."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create a user to delete # Create a user to delete
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -332,7 +367,7 @@ class TestCRUDBaseRemove:
email="todelete@example.com", email="todelete@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="To", first_name="To",
last_name="Delete" last_name="Delete",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
@@ -347,7 +382,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_nonexistent(self, async_test_db): async def test_remove_nonexistent(self, async_test_db):
"""Test remove of nonexistent record returns None.""" """Test remove of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.remove(session, id=str(uuid4())) result = await user_crud.remove(session, id=str(uuid4()))
@@ -356,21 +391,31 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_integrity_error(self, async_test_db, async_test_user): async def test_remove_integrity_error(self, async_test_db, async_test_user):
"""Test remove with IntegrityError (foreign key constraint).""" """Test remove with IntegrityError (foreign key constraint)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
# Mock delete to raise IntegrityError # Mock delete to raise IntegrityError
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("FOREIGN KEY constraint"))): with patch.object(
with pytest.raises(ValueError, match="Cannot delete.*referenced by other records"): session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("FOREIGN KEY constraint")
),
):
with pytest.raises(
ValueError, match="Cannot delete.*referenced by other records"
):
await user_crud.remove(session, id=str(async_test_user.id)) await user_crud.remove(session, id=str(async_test_user.id))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_unexpected_error(self, async_test_db, async_test_user): async def test_remove_unexpected_error(self, async_test_db, async_test_user):
"""Test remove with unexpected error.""" """Test remove with unexpected error."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")): with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await user_crud.remove(session, id=str(async_test_user.id)) await user_crud.remove(session, id=str(async_test_user.id))
@@ -381,10 +426,12 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user): async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
"""Test get_multi_with_total basic functionality.""" """Test get_multi_with_total basic functionality."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total(session, skip=0, limit=10) items, total = await user_crud.get_multi_with_total(
session, skip=0, limit=10
)
assert isinstance(items, list) assert isinstance(items, list)
assert isinstance(total, int) assert isinstance(total, int)
assert total >= 1 # At least the test user assert total >= 1 # At least the test user
@@ -392,7 +439,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_negative_skip(self, async_test_db): async def test_get_multi_with_total_negative_skip(self, async_test_db):
"""Test get_multi_with_total with negative skip raises ValueError.""" """Test get_multi_with_total with negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"): with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -401,7 +448,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db): async def test_get_multi_with_total_negative_limit(self, async_test_db):
"""Test get_multi_with_total with negative limit raises ValueError.""" """Test get_multi_with_total with negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"): with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -410,28 +457,34 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db): async def test_get_multi_with_total_limit_too_large(self, async_test_db):
"""Test get_multi_with_total with limit > 1000 raises ValueError.""" """Test get_multi_with_total with limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"): with pytest.raises(ValueError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, limit=1001) await user_crud.get_multi_with_total(session, limit=1001)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user): async def test_get_multi_with_total_with_filters(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with filters.""" """Test get_multi_with_total with filters."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
filters = {"email": async_test_user.email} filters = {"email": async_test_user.email}
items, total = await user_crud.get_multi_with_total(session, filters=filters) items, total = await user_crud.get_multi_with_total(
session, filters=filters
)
assert total == 1 assert total == 1
assert len(items) == 1 assert len(items) == 1
assert items[0].email == async_test_user.email assert items[0].email == async_test_user.email
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db, async_test_user): async def test_get_multi_with_total_with_sorting_asc(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with ascending sort.""" """Test get_multi_with_total with ascending sort."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create additional users # Create additional users
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -439,13 +492,13 @@ class TestCRUDBaseGetMultiWithTotal:
email="aaa@example.com", email="aaa@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="AAA", first_name="AAA",
last_name="User" last_name="User",
) )
user_data2 = UserCreate( user_data2 = UserCreate(
email="zzz@example.com", email="zzz@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="ZZZ", first_name="ZZZ",
last_name="User" last_name="User",
) )
await user_crud.create(session, obj_in=user_data1) await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2) await user_crud.create(session, obj_in=user_data2)
@@ -460,9 +513,11 @@ class TestCRUDBaseGetMultiWithTotal:
assert items[0].email == "aaa@example.com" assert items[0].email == "aaa@example.com"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db, async_test_user): async def test_get_multi_with_total_with_sorting_desc(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with descending sort.""" """Test get_multi_with_total with descending sort."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create additional users # Create additional users
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -470,20 +525,20 @@ class TestCRUDBaseGetMultiWithTotal:
email="bbb@example.com", email="bbb@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="BBB", first_name="BBB",
last_name="User" last_name="User",
) )
user_data2 = UserCreate( user_data2 = UserCreate(
email="ccc@example.com", email="ccc@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="CCC", first_name="CCC",
last_name="User" last_name="User",
) )
await user_crud.create(session, obj_in=user_data1) await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2) await user_crud.create(session, obj_in=user_data2)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total( items, _total = await user_crud.get_multi_with_total(
session, sort_by="email", sort_order="desc", limit=1 session, sort_by="email", sort_order="desc", limit=1
) )
assert len(items) == 1 assert len(items) == 1
@@ -492,7 +547,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_pagination(self, async_test_db): async def test_get_multi_with_total_with_pagination(self, async_test_db):
"""Test get_multi_with_total pagination works correctly.""" """Test get_multi_with_total pagination works correctly."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create minimal users for pagination test (3 instead of 5) # Create minimal users for pagination test (3 instead of 5)
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -501,19 +556,23 @@ class TestCRUDBaseGetMultiWithTotal:
email=f"user{i}@example.com", email=f"user{i}@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name=f"User{i}", first_name=f"User{i}",
last_name="Test" last_name="Test",
) )
await user_crud.create(session, obj_in=user_data) await user_crud.create(session, obj_in=user_data)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
# Get first page # Get first page
items1, total = await user_crud.get_multi_with_total(session, skip=0, limit=2) items1, total = await user_crud.get_multi_with_total(
session, skip=0, limit=2
)
assert len(items1) == 2 assert len(items1) == 2
assert total >= 3 assert total >= 3
# Get second page # Get second page
items2, total2 = await user_crud.get_multi_with_total(session, skip=2, limit=2) items2, total2 = await user_crud.get_multi_with_total(
session, skip=2, limit=2
)
assert len(items2) >= 1 assert len(items2) >= 1
assert total2 == total assert total2 == total
@@ -529,7 +588,7 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_count_basic(self, async_test_db, async_test_user): async def test_count_basic(self, async_test_db, async_test_user):
"""Test count returns correct number.""" """Test count returns correct number."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
count = await user_crud.count(session) count = await user_crud.count(session)
@@ -539,7 +598,7 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_count_multiple_users(self, async_test_db, async_test_user): async def test_count_multiple_users(self, async_test_db, async_test_user):
"""Test count with multiple users.""" """Test count with multiple users."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create additional users # Create additional users
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -549,13 +608,13 @@ class TestCRUDBaseCount:
email="count1@example.com", email="count1@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Count", first_name="Count",
last_name="One" last_name="One",
) )
user_data2 = UserCreate( user_data2 = UserCreate(
email="count2@example.com", email="count2@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Count", first_name="Count",
last_name="Two" last_name="Two",
) )
await user_crud.create(session, obj_in=user_data1) await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2) await user_crud.create(session, obj_in=user_data2)
@@ -568,10 +627,10 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_count_database_error(self, async_test_db): async def test_count_database_error(self, async_test_db):
"""Test count handles database errors.""" """Test count handles database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("DB error")): with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"): with pytest.raises(Exception, match="DB error"):
await user_crud.count(session) await user_crud.count(session)
@@ -582,7 +641,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exists_true(self, async_test_db, async_test_user): async def test_exists_true(self, async_test_db, async_test_user):
"""Test exists returns True for existing record.""" """Test exists returns True for existing record."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(async_test_user.id)) result = await user_crud.exists(session, id=str(async_test_user.id))
@@ -591,7 +650,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exists_false(self, async_test_db): async def test_exists_false(self, async_test_db):
"""Test exists returns False for non-existent record.""" """Test exists returns False for non-existent record."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(uuid4())) result = await user_crud.exists(session, id=str(uuid4()))
@@ -600,7 +659,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exists_invalid_uuid(self, async_test_db): async def test_exists_invalid_uuid(self, async_test_db):
"""Test exists returns False for invalid UUID.""" """Test exists returns False for invalid UUID."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.exists(session, id="invalid-uuid") result = await user_crud.exists(session, id="invalid-uuid")
@@ -613,7 +672,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_soft_delete_success(self, async_test_db): async def test_soft_delete_success(self, async_test_db):
"""Test soft delete sets deleted_at timestamp.""" """Test soft delete sets deleted_at timestamp."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create a user to soft delete # Create a user to soft delete
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -621,7 +680,7 @@ class TestCRUDBaseSoftDelete:
email="softdelete@example.com", email="softdelete@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Soft", first_name="Soft",
last_name="Delete" last_name="Delete",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
@@ -636,7 +695,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_soft_delete_invalid_uuid(self, async_test_db): async def test_soft_delete_invalid_uuid(self, async_test_db):
"""Test soft delete with invalid UUID returns None.""" """Test soft delete with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id="invalid-uuid") result = await user_crud.soft_delete(session, id="invalid-uuid")
@@ -645,7 +704,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_soft_delete_nonexistent(self, async_test_db): async def test_soft_delete_nonexistent(self, async_test_db):
"""Test soft delete of nonexistent record returns None.""" """Test soft delete of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id=str(uuid4())) result = await user_crud.soft_delete(session, id=str(uuid4()))
@@ -654,7 +713,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_soft_delete_with_uuid_object(self, async_test_db): async def test_soft_delete_with_uuid_object(self, async_test_db):
"""Test soft delete with UUID object.""" """Test soft delete with UUID object."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create a user to soft delete # Create a user to soft delete
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -662,7 +721,7 @@ class TestCRUDBaseSoftDelete:
email="softdelete2@example.com", email="softdelete2@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Soft", first_name="Soft",
last_name="Delete2" last_name="Delete2",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
@@ -681,7 +740,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_success(self, async_test_db): async def test_restore_success(self, async_test_db):
"""Test restore clears deleted_at timestamp.""" """Test restore clears deleted_at timestamp."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create and soft delete a user # Create and soft delete a user
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -689,7 +748,7 @@ class TestCRUDBaseRestore:
email="restore@example.com", email="restore@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Restore", first_name="Restore",
last_name="Test" last_name="Test",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
@@ -707,7 +766,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_invalid_uuid(self, async_test_db): async def test_restore_invalid_uuid(self, async_test_db):
"""Test restore with invalid UUID returns None.""" """Test restore with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.restore(session, id="invalid-uuid") result = await user_crud.restore(session, id="invalid-uuid")
@@ -716,7 +775,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_nonexistent(self, async_test_db): async def test_restore_nonexistent(self, async_test_db):
"""Test restore of nonexistent record returns None.""" """Test restore of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.restore(session, id=str(uuid4())) result = await user_crud.restore(session, id=str(uuid4()))
@@ -725,7 +784,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_not_deleted(self, async_test_db, async_test_user): async def test_restore_not_deleted(self, async_test_db, async_test_user):
"""Test restore of non-deleted record returns None.""" """Test restore of non-deleted record returns None."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
# Try to restore a user that's not deleted # Try to restore a user that's not deleted
@@ -735,7 +794,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_with_uuid_object(self, async_test_db): async def test_restore_with_uuid_object(self, async_test_db):
"""Test restore with UUID object.""" """Test restore with UUID object."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create and soft delete a user # Create and soft delete a user
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -743,7 +802,7 @@ class TestCRUDBaseRestore:
email="restore2@example.com", email="restore2@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Restore", first_name="Restore",
last_name="Test2" last_name="Test2",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
@@ -765,7 +824,7 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_negative_skip(self, async_test_db): async def test_get_multi_with_total_negative_skip(self, async_test_db):
"""Test that negative skip raises ValueError.""" """Test that negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"): with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -774,7 +833,7 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db): async def test_get_multi_with_total_negative_limit(self, async_test_db):
"""Test that negative limit raises ValueError.""" """Test that negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"): with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -783,23 +842,22 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db): async def test_get_multi_with_total_limit_too_large(self, async_test_db):
"""Test that limit > 1000 raises ValueError.""" """Test that limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"): with pytest.raises(ValueError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, skip=0, limit=1001) await user_crud.get_multi_with_total(session, skip=0, limit=1001)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user): async def test_get_multi_with_total_with_filters(
self, async_test_db, async_test_user
):
"""Test pagination with filters (covers lines 270-273).""" """Test pagination with filters (covers lines 270-273)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, total = await user_crud.get_multi_with_total(
session, session, skip=0, limit=10, filters={"is_active": True}
skip=0,
limit=10,
filters={"is_active": True}
) )
assert isinstance(users, list) assert isinstance(users, list)
assert total >= 0 assert total >= 0
@@ -807,30 +865,22 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db): async def test_get_multi_with_total_with_sorting_desc(self, async_test_db):
"""Test pagination with descending sort (covers lines 283-284).""" """Test pagination with descending sort (covers lines 283-284)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, _total = await user_crud.get_multi_with_total(
session, session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
skip=0,
limit=10,
sort_by="created_at",
sort_order="desc"
) )
assert isinstance(users, list) assert isinstance(users, list)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db): async def test_get_multi_with_total_with_sorting_asc(self, async_test_db):
"""Test pagination with ascending sort (covers lines 285-286).""" """Test pagination with ascending sort (covers lines 285-286)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, _total = await user_crud.get_multi_with_total(
session, session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
skip=0,
limit=10,
sort_by="created_at",
sort_order="asc"
) )
assert isinstance(users, list) assert isinstance(users, list)
@@ -842,13 +892,15 @@ class TestCRUDBaseModelsWithoutSoftDelete:
""" """
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_soft_delete_model_without_deleted_at(self, async_test_db, async_test_user): async def test_soft_delete_model_without_deleted_at(
self, async_test_db, async_test_user
):
"""Test soft_delete on Organization model (no deleted_at) raises ValueError (covers lines 342-343).""" """Test soft_delete on Organization model (no deleted_at) raises ValueError (covers lines 342-343)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at) # Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.crud.organization import organization as org_crud from app.crud.organization import organization as org_crud
from app.models.organization import Organization
async with SessionLocal() as session: async with SessionLocal() as session:
org = Organization(name="Test Org", slug="test-org") org = Organization(name="Test Org", slug="test-org")
@@ -864,11 +916,11 @@ class TestCRUDBaseModelsWithoutSoftDelete:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_model_without_deleted_at(self, async_test_db): async def test_restore_model_without_deleted_at(self, async_test_db):
"""Test restore on Organization model (no deleted_at) raises ValueError (covers lines 383-384).""" """Test restore on Organization model (no deleted_at) raises ValueError (covers lines 383-384)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at) # Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.crud.organization import organization as org_crud from app.crud.organization import organization as org_crud
from app.models.organization import Organization
async with SessionLocal() as session: async with SessionLocal() as session:
org = Organization(name="Restore Test", slug="restore-test") org = Organization(name="Restore Test", slug="restore-test")
@@ -889,14 +941,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
""" """
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_with_real_eager_loading_options(self, async_test_db, async_test_user): async def test_get_with_real_eager_loading_options(
self, async_test_db, async_test_user
):
"""Test get() with actual eager loading options (covers lines 77-78).""" """Test get() with actual eager loading options (covers lines 77-78)."""
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session for the user # Create a session for the user
from app.models.user_session import UserSession
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.models.user_session import UserSession
async with SessionLocal() as session: async with SessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -905,8 +960,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
device_id="test-device", device_id="test-device",
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Test Agent", user_agent="Test Agent",
last_used_at=datetime.now(timezone.utc), last_used_at=datetime.now(UTC),
expires_at=datetime.now(timezone.utc) + timedelta(days=60) expires_at=datetime.now(UTC) + timedelta(days=60),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -917,7 +972,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
result = await session_crud.get( result = await session_crud.get(
session, session,
id=str(session_id), id=str(session_id),
options=[joinedload(UserSession.user)] # Real option, not empty list options=[joinedload(UserSession.user)], # Real option, not empty list
) )
assert result is not None assert result is not None
assert result.id == session_id assert result.id == session_id
@@ -925,14 +980,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
assert result.user.email == async_test_user.email assert result.user.email == async_test_user.email
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_real_eager_loading_options(self, async_test_db, async_test_user): async def test_get_multi_with_real_eager_loading_options(
self, async_test_db, async_test_user
):
"""Test get_multi() with actual eager loading options (covers lines 119-120).""" """Test get_multi() with actual eager loading options (covers lines 119-120)."""
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create multiple sessions for the user # Create multiple sessions for the user
from app.models.user_session import UserSession
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.models.user_session import UserSession
async with SessionLocal() as session: async with SessionLocal() as session:
for i in range(3): for i in range(3):
@@ -942,8 +1000,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
device_id=f"device-{i}", device_id=f"device-{i}",
ip_address=f"192.168.1.{i}", ip_address=f"192.168.1.{i}",
user_agent=f"Agent {i}", user_agent=f"Agent {i}",
last_used_at=datetime.now(timezone.utc), last_used_at=datetime.now(UTC),
expires_at=datetime.now(timezone.utc) + timedelta(days=60) expires_at=datetime.now(UTC) + timedelta(days=60),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -954,7 +1012,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
session, session,
skip=0, skip=0,
limit=10, limit=10,
options=[joinedload(UserSession.user)] # Real option, not empty list options=[joinedload(UserSession.user)], # Real option, not empty list
) )
assert len(results) >= 3 assert len(results) >= 3
# Verify we can access user without additional queries # Verify we can access user without additional queries

View File

@@ -3,13 +3,15 @@
Comprehensive tests for base CRUD database failure scenarios. Comprehensive tests for base CRUD database failure scenarios.
Tests exception handling, rollbacks, and error messages. Tests exception handling, rollbacks, and error messages.
""" """
import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from uuid import uuid4 from uuid import uuid4
import pytest
from sqlalchemy.exc import DataError, OperationalError
from app.crud.user import user as user_crud from app.crud.user import user as user_crud
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate
class TestBaseCRUDCreateFailures: class TestBaseCRUDCreateFailures:
@@ -18,19 +20,24 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_operational_error_triggers_rollback(self, async_test_db): async def test_create_operational_error_triggers_rollback(self, async_test_db):
"""Test that OperationalError triggers rollback (User CRUD catches as Exception).""" """Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Connection lost", {}, Exception("DB connection failed"))
with patch.object(session, 'commit', side_effect=mock_commit): async def mock_commit():
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: raise OperationalError(
"Connection lost", {}, Exception("DB connection failed")
)
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate( user_data = UserCreate(
email="operror@example.com", email="operror@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Test", first_name="Test",
last_name="User" last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User CRUD catches this as generic Exception and re-raises
@@ -43,19 +50,22 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_data_error_triggers_rollback(self, async_test_db): async def test_create_data_error_triggers_rollback(self, async_test_db):
"""Test that DataError triggers rollback (User CRUD catches as Exception).""" """Test that DataError triggers rollback (User CRUD catches as Exception)."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise DataError("Invalid data type", {}, Exception("Data overflow")) raise DataError("Invalid data type", {}, Exception("Data overflow"))
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate( user_data = UserCreate(
email="dataerror@example.com", email="dataerror@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Test", first_name="Test",
last_name="User" last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User CRUD catches this as generic Exception and re-raises
@@ -67,19 +77,22 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_unexpected_exception_triggers_rollback(self, async_test_db): async def test_create_unexpected_exception_triggers_rollback(self, async_test_db):
"""Test that unexpected exceptions trigger rollback and re-raise.""" """Test that unexpected exceptions trigger rollback and re-raise."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise RuntimeError("Unexpected database error") raise RuntimeError("Unexpected database error")
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate( user_data = UserCreate(
email="unexpected@example.com", email="unexpected@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Test", first_name="Test",
last_name="User" last_name="User",
) )
with pytest.raises(RuntimeError, match="Unexpected database error"): with pytest.raises(RuntimeError, match="Unexpected database error"):
@@ -94,7 +107,7 @@ class TestBaseCRUDUpdateFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user): async def test_update_operational_error(self, async_test_db, async_test_user):
"""Test update with OperationalError.""" """Test update with OperationalError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
@@ -102,17 +115,21 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit(): async def mock_commit():
raise OperationalError("Connection timeout", {}, Exception("Timeout")) raise OperationalError("Connection timeout", {}, Exception("Timeout"))
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"): with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"}) await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_data_error(self, async_test_db, async_test_user): async def test_update_data_error(self, async_test_db, async_test_user):
"""Test update with DataError.""" """Test update with DataError."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
@@ -120,17 +137,21 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit(): async def mock_commit():
raise DataError("Invalid data", {}, Exception("Data type mismatch")) raise DataError("Invalid data", {}, Exception("Data type mismatch"))
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"): with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"}) await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_unexpected_error(self, async_test_db, async_test_user): async def test_update_unexpected_error(self, async_test_db, async_test_user):
"""Test update with unexpected error.""" """Test update with unexpected error."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
@@ -138,10 +159,14 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit(): async def mock_commit():
raise KeyError("Unexpected error") raise KeyError("Unexpected error")
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(KeyError): with pytest.raises(KeyError):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"}) await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@@ -150,16 +175,21 @@ class TestBaseCRUDRemoveFailures:
"""Test base CRUD remove method exception handling.""" """Test base CRUD remove method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_unexpected_error_triggers_rollback(self, async_test_db, async_test_user): async def test_remove_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test that unexpected errors in remove trigger rollback.""" """Test that unexpected errors in remove trigger rollback."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise RuntimeError("Database write failed") raise RuntimeError("Database write failed")
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Database write failed"): with pytest.raises(RuntimeError, match="Database write failed"):
await user_crud.remove(session, id=str(async_test_user.id)) await user_crud.remove(session, id=str(async_test_user.id))
@@ -172,16 +202,15 @@ class TestBaseCRUDGetMultiWithTotalFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_database_error(self, async_test_db): async def test_get_multi_with_total_database_error(self, async_test_db):
"""Test get_multi_with_total handles database errors.""" """Test get_multi_with_total handles database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
# Mock execute to raise an error # Mock execute to raise an error
original_execute = session.execute
async def mock_execute(*args, **kwargs): async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("Database error")) raise OperationalError("Query failed", {}, Exception("Database error"))
with patch.object(session, 'execute', side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.get_multi_with_total(session, skip=0, limit=10) await user_crud.get_multi_with_total(session, skip=0, limit=10)
@@ -192,13 +221,14 @@ class TestBaseCRUDCountFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_count_database_error_propagates(self, async_test_db): async def test_count_database_error_propagates(self, async_test_db):
"""Test count propagates database errors.""" """Test count propagates database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_execute(*args, **kwargs): async def mock_execute(*args, **kwargs):
raise OperationalError("Count failed", {}, Exception("DB error")) raise OperationalError("Count failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.count(session) await user_crud.count(session)
@@ -207,16 +237,21 @@ class TestBaseCRUDSoftDeleteFailures:
"""Test soft_delete method exception handling.""" """Test soft_delete method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_soft_delete_unexpected_error_triggers_rollback(self, async_test_db, async_test_user): async def test_soft_delete_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test soft_delete handles unexpected errors with rollback.""" """Test soft_delete handles unexpected errors with rollback."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise RuntimeError("Soft delete failed") raise RuntimeError("Soft delete failed")
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Soft delete failed"): with pytest.raises(RuntimeError, match="Soft delete failed"):
await user_crud.soft_delete(session, id=str(async_test_user.id)) await user_crud.soft_delete(session, id=str(async_test_user.id))
@@ -229,7 +264,7 @@ class TestBaseCRUDRestoreFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_unexpected_error_triggers_rollback(self, async_test_db): async def test_restore_unexpected_error_triggers_rollback(self, async_test_db):
"""Test restore handles unexpected errors with rollback.""" """Test restore handles unexpected errors with rollback."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# First create and soft delete a user # First create and soft delete a user
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -237,7 +272,7 @@ class TestBaseCRUDRestoreFailures:
email="restore_test@example.com", email="restore_test@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="Restore", first_name="Restore",
last_name="Test" last_name="Test",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
@@ -248,11 +283,14 @@ class TestBaseCRUDRestoreFailures:
# Now test restore failure # Now test restore failure
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise RuntimeError("Restore failed") raise RuntimeError("Restore failed")
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Restore failed"): with pytest.raises(RuntimeError, match="Restore failed"):
await user_crud.restore(session, id=str(user_id)) await user_crud.restore(session, id=str(user_id))
@@ -265,13 +303,14 @@ class TestBaseCRUDGetFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_database_error_propagates(self, async_test_db): async def test_get_database_error_propagates(self, async_test_db):
"""Test get propagates database errors.""" """Test get propagates database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_execute(*args, **kwargs): async def mock_execute(*args, **kwargs):
raise OperationalError("Get failed", {}, Exception("DB error")) raise OperationalError("Get failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.get(session, id=str(uuid4())) await user_crud.get(session, id=str(uuid4()))
@@ -282,12 +321,13 @@ class TestBaseCRUDGetMultiFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_database_error_propagates(self, async_test_db): async def test_get_multi_database_error_propagates(self, async_test_db):
"""Test get_multi propagates database errors.""" """Test get_multi propagates database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_execute(*args, **kwargs): async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("DB error")) raise OperationalError("Query failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.get_multi(session, skip=0, limit=10) await user_crud.get_multi(session, skip=0, limit=10)

File diff suppressed because it is too large Load Diff

View File

@@ -2,10 +2,12 @@
""" """
Comprehensive tests for async session CRUD operations. Comprehensive tests for async session CRUD operations.
""" """
import pytest
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
from uuid import uuid4 from uuid import uuid4
import pytest
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
@@ -17,7 +19,7 @@ class TestGetByJti:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_by_jti_success(self, async_test_db, async_test_user): async def test_get_by_jti_success(self, async_test_db, async_test_user):
"""Test getting session by JTI.""" """Test getting session by JTI."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -27,8 +29,8 @@ class TestGetByJti:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -41,7 +43,7 @@ class TestGetByJti:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_by_jti_not_found(self, async_test_db): async def test_get_by_jti_not_found(self, async_test_db):
"""Test getting non-existent JTI returns None.""" """Test getting non-existent JTI returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="nonexistent") result = await session_crud.get_by_jti(session, jti="nonexistent")
@@ -54,7 +56,7 @@ class TestGetActiveByJti:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_active_by_jti_success(self, async_test_db, async_test_user): async def test_get_active_by_jti_success(self, async_test_db, async_test_user):
"""Test getting active session by JTI.""" """Test getting active session by JTI."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -64,8 +66,8 @@ class TestGetActiveByJti:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -78,7 +80,7 @@ class TestGetActiveByJti:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user): async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user):
"""Test getting inactive session by JTI returns None.""" """Test getting inactive session by JTI returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -88,8 +90,8 @@ class TestGetActiveByJti:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -105,7 +107,7 @@ class TestGetUserSessions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user): async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
"""Test getting only active user sessions.""" """Test getting only active user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
active = UserSession( active = UserSession(
@@ -115,8 +117,8 @@ class TestGetUserSessions:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
inactive = UserSession( inactive = UserSession(
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -125,17 +127,15 @@ class TestGetUserSessions:
ip_address="192.168.1.2", ip_address="192.168.1.2",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add_all([active, inactive]) session.add_all([active, inactive])
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions( results = await session_crud.get_user_sessions(
session, session, user_id=str(async_test_user.id), active_only=True
user_id=str(async_test_user.id),
active_only=True
) )
assert len(results) == 1 assert len(results) == 1
assert results[0].is_active is True assert results[0].is_active is True
@@ -143,7 +143,7 @@ class TestGetUserSessions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_sessions_all(self, async_test_db, async_test_user): async def test_get_user_sessions_all(self, async_test_db, async_test_user):
"""Test getting all user sessions.""" """Test getting all user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
for i in range(3): for i in range(3):
@@ -154,17 +154,15 @@ class TestGetUserSessions:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=i % 2 == 0, is_active=i % 2 == 0,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(sess) session.add(sess)
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions( results = await session_crud.get_user_sessions(
session, session, user_id=str(async_test_user.id), active_only=False
user_id=str(async_test_user.id),
active_only=False
) )
assert len(results) == 3 assert len(results) == 3
@@ -175,7 +173,7 @@ class TestCreateSession:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_session_success(self, async_test_db, async_test_user): async def test_create_session_success(self, async_test_db, async_test_user):
"""Test successfully creating a session_crud.""" """Test successfully creating a session_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
session_data = SessionCreate( session_data = SessionCreate(
@@ -185,10 +183,10 @@ class TestCreateSession:
device_id="device_123", device_id="device_123",
ip_address="192.168.1.100", ip_address="192.168.1.100",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
last_used_at=datetime.now(timezone.utc), last_used_at=datetime.now(UTC),
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
location_city="San Francisco", location_city="San Francisco",
location_country="USA" location_country="USA",
) )
result = await session_crud.create_session(session, obj_in=session_data) result = await session_crud.create_session(session, obj_in=session_data)
@@ -204,7 +202,7 @@ class TestDeactivate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deactivate_success(self, async_test_db, async_test_user): async def test_deactivate_success(self, async_test_db, async_test_user):
"""Test successfully deactivating a session_crud.""" """Test successfully deactivating a session_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -214,8 +212,8 @@ class TestDeactivate:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -229,7 +227,7 @@ class TestDeactivate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deactivate_not_found(self, async_test_db): async def test_deactivate_not_found(self, async_test_db):
"""Test deactivating non-existent session returns None.""" """Test deactivating non-existent session returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(uuid4())) result = await session_crud.deactivate(session, session_id=str(uuid4()))
@@ -240,9 +238,11 @@ class TestDeactivateAllUserSessions:
"""Tests for deactivate_all_user_sessions method.""" """Tests for deactivate_all_user_sessions method."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deactivate_all_user_sessions_success(self, async_test_db, async_test_user): async def test_deactivate_all_user_sessions_success(
self, async_test_db, async_test_user
):
"""Test deactivating all user sessions.""" """Test deactivating all user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Create minimal sessions for test (2 instead of 5) # Create minimal sessions for test (2 instead of 5)
@@ -254,16 +254,15 @@ class TestDeactivateAllUserSessions:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(sess) session.add(sess)
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.deactivate_all_user_sessions( count = await session_crud.deactivate_all_user_sessions(
session, session, user_id=str(async_test_user.id)
user_id=str(async_test_user.id)
) )
assert count == 2 assert count == 2
@@ -274,7 +273,7 @@ class TestUpdateLastUsed:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_last_used_success(self, async_test_db, async_test_user): async def test_update_last_used_success(self, async_test_db, async_test_user):
"""Test updating last_used_at timestamp.""" """Test updating last_used_at timestamp."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -284,8 +283,8 @@ class TestUpdateLastUsed:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1) last_used_at=datetime.now(UTC) - timedelta(hours=1),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -303,7 +302,7 @@ class TestGetUserSessionCount:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_session_count_success(self, async_test_db, async_test_user): async def test_get_user_session_count_success(self, async_test_db, async_test_user):
"""Test getting user session count.""" """Test getting user session count."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
for i in range(3): for i in range(3):
@@ -314,28 +313,26 @@ class TestGetUserSessionCount:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(sess) session.add(sess)
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count( count = await session_crud.get_user_session_count(
session, session, user_id=str(async_test_user.id)
user_id=str(async_test_user.id)
) )
assert count == 3 assert count == 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_session_count_empty(self, async_test_db): async def test_get_user_session_count_empty(self, async_test_db):
"""Test getting session count for user with no sessions.""" """Test getting session count for user with no sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count( count = await session_crud.get_user_session_count(
session, session, user_id=str(uuid4())
user_id=str(uuid4())
) )
assert count == 0 assert count == 0
@@ -346,7 +343,7 @@ class TestUpdateRefreshToken:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_refresh_token_success(self, async_test_db, async_test_user): async def test_update_refresh_token_success(self, async_test_db, async_test_user):
"""Test updating refresh token JTI and expiration.""" """Test updating refresh token JTI and expiration."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -356,26 +353,34 @@ class TestUpdateRefreshToken:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1) last_used_at=datetime.now(UTC) - timedelta(hours=1),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
await session.refresh(user_session) await session.refresh(user_session)
new_jti = "new_jti_123" new_jti = "new_jti_123"
new_expires = datetime.now(timezone.utc) + timedelta(days=14) new_expires = datetime.now(UTC) + timedelta(days=14)
result = await session_crud.update_refresh_token( result = await session_crud.update_refresh_token(
session, session,
session=user_session, session=user_session,
new_jti=new_jti, new_jti=new_jti,
new_expires_at=new_expires new_expires_at=new_expires,
) )
assert result.refresh_token_jti == new_jti assert result.refresh_token_jti == new_jti
# Compare timestamps ignoring timezone info # Compare timestamps ignoring timezone info
assert abs((result.expires_at.replace(tzinfo=None) - new_expires.replace(tzinfo=None)).total_seconds()) < 1 assert (
abs(
(
result.expires_at.replace(tzinfo=None)
- new_expires.replace(tzinfo=None)
).total_seconds()
)
< 1
)
class TestCleanupExpired: class TestCleanupExpired:
@@ -384,7 +389,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_success(self, async_test_db, async_test_user): async def test_cleanup_expired_success(self, async_test_db, async_test_user):
"""Test cleaning up old expired inactive sessions.""" """Test cleaning up old expired inactive sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired inactive session # Create old expired inactive session
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -395,9 +400,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=5), expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc) - timedelta(days=35), last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(timezone.utc) - timedelta(days=35) created_at=datetime.now(UTC) - timedelta(days=35),
) )
session.add(old_session) session.add(old_session)
await session.commit() await session.commit()
@@ -410,7 +415,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user): async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user):
"""Test that cleanup keeps recent expired sessions.""" """Test that cleanup keeps recent expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create recent expired inactive session (less than keep_days old) # Create recent expired inactive session (less than keep_days old)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -421,9 +426,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1), expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2), last_used_at=datetime.now(UTC) - timedelta(hours=2),
created_at=datetime.now(timezone.utc) - timedelta(days=1) created_at=datetime.now(UTC) - timedelta(days=1),
) )
session.add(recent_session) session.add(recent_session)
await session.commit() await session.commit()
@@ -436,7 +441,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user): async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user):
"""Test that cleanup does not delete active sessions.""" """Test that cleanup does not delete active sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired but ACTIVE session # Create old expired but ACTIVE session
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -447,9 +452,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, # Active is_active=True, # Active
expires_at=datetime.now(timezone.utc) - timedelta(days=5), expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc) - timedelta(days=35), last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(timezone.utc) - timedelta(days=35) created_at=datetime.now(UTC) - timedelta(days=35),
) )
session.add(active_session) session.add(active_session)
await session.commit() await session.commit()
@@ -464,9 +469,11 @@ class TestCleanupExpiredForUser:
"""Tests for cleanup_expired_for_user method.""" """Tests for cleanup_expired_for_user method."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_for_user_success(self, async_test_db, async_test_user): async def test_cleanup_expired_for_user_success(
self, async_test_db, async_test_user
):
"""Test cleaning up expired sessions for specific user.""" """Test cleaning up expired sessions for specific user."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired inactive session for user # Create expired inactive session for user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -477,8 +484,8 @@ class TestCleanupExpiredForUser:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2) last_used_at=datetime.now(UTC) - timedelta(days=2),
) )
session.add(expired_session) session.add(expired_session)
await session.commit() await session.commit()
@@ -486,27 +493,27 @@ class TestCleanupExpiredForUser:
# Cleanup for user # Cleanup for user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user( count = await session_crud.cleanup_expired_for_user(
session, session, user_id=str(async_test_user.id)
user_id=str(async_test_user.id)
) )
assert count == 1 assert count == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db): async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db):
"""Test cleanup with invalid user UUID.""" """Test cleanup with invalid user UUID."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError, match="Invalid user ID format"): with pytest.raises(ValueError, match="Invalid user ID format"):
await session_crud.cleanup_expired_for_user( await session_crud.cleanup_expired_for_user(
session, session, user_id="not-a-valid-uuid"
user_id="not-a-valid-uuid"
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_for_user_keeps_active(self, async_test_db, async_test_user): async def test_cleanup_expired_for_user_keeps_active(
self, async_test_db, async_test_user
):
"""Test that cleanup for user keeps active sessions.""" """Test that cleanup for user keeps active sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired but active session # Create expired but active session
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -517,8 +524,8 @@ class TestCleanupExpiredForUser:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, # Active is_active=True, # Active
expires_at=datetime.now(timezone.utc) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2) last_used_at=datetime.now(UTC) - timedelta(days=2),
) )
session.add(active_session) session.add(active_session)
await session.commit() await session.commit()
@@ -526,8 +533,7 @@ class TestCleanupExpiredForUser:
# Cleanup # Cleanup
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user( count = await session_crud.cleanup_expired_for_user(
session, session, user_id=str(async_test_user.id)
user_id=str(async_test_user.id)
) )
assert count == 0 # Should not delete active sessions assert count == 0 # Should not delete active sessions
@@ -536,9 +542,11 @@ class TestGetUserSessionsWithUser:
"""Tests for get_user_sessions with eager loading.""" """Tests for get_user_sessions with eager loading."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_sessions_with_user_relationship(self, async_test_db, async_test_user): async def test_get_user_sessions_with_user_relationship(
self, async_test_db, async_test_user
):
"""Test getting sessions with user relationship loaded.""" """Test getting sessions with user relationship loaded."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -548,8 +556,8 @@ class TestGetUserSessionsWithUser:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -557,8 +565,6 @@ class TestGetUserSessionsWithUser:
# Get with user relationship # Get with user relationship
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions( results = await session_crud.get_user_sessions(
session, session, user_id=str(async_test_user.id), with_user=True
user_id=str(async_test_user.id),
with_user=True
) )
assert len(results) >= 1 assert len(results) >= 1

View File

@@ -2,12 +2,14 @@
""" """
Comprehensive tests for session CRUD database failure scenarios. Comprehensive tests for session CRUD database failure scenarios.
""" """
import pytest
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from sqlalchemy.exc import OperationalError, IntegrityError
from datetime import datetime, timedelta, timezone
from uuid import uuid4 from uuid import uuid4
import pytest
from sqlalchemy.exc import OperationalError
from app.crud.session import session as session_crud from app.crud.session import session as session_crud
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
@@ -19,13 +21,14 @@ class TestSessionCRUDGetByJtiFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_by_jti_database_error(self, async_test_db): async def test_get_by_jti_database_error(self, async_test_db):
"""Test get_by_jti handles database errors.""" """Test get_by_jti handles database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_execute(*args, **kwargs): async def mock_execute(*args, **kwargs):
raise OperationalError("DB connection lost", {}, Exception()) raise OperationalError("DB connection lost", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_by_jti(session, jti="test_jti") await session_crud.get_by_jti(session, jti="test_jti")
@@ -36,13 +39,14 @@ class TestSessionCRUDGetActiveByJtiFailures:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_active_by_jti_database_error(self, async_test_db): async def test_get_active_by_jti_database_error(self, async_test_db):
"""Test get_active_by_jti handles database errors.""" """Test get_active_by_jti handles database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_execute(*args, **kwargs): async def mock_execute(*args, **kwargs):
raise OperationalError("Query timeout", {}, Exception()) raise OperationalError("Query timeout", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_active_by_jti(session, jti="test_jti") await session_crud.get_active_by_jti(session, jti="test_jti")
@@ -51,19 +55,21 @@ class TestSessionCRUDGetUserSessionsFailures:
"""Test get_user_sessions exception handling.""" """Test get_user_sessions exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_sessions_database_error(self, async_test_db, async_test_user): async def test_get_user_sessions_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_sessions handles database errors.""" """Test get_user_sessions handles database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_execute(*args, **kwargs): async def mock_execute(*args, **kwargs):
raise OperationalError("Database error", {}, Exception()) raise OperationalError("Database error", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_user_sessions( await session_crud.get_user_sessions(
session, session, user_id=str(async_test_user.id)
user_id=str(async_test_user.id)
) )
@@ -71,24 +77,29 @@ class TestSessionCRUDCreateSessionFailures:
"""Test create_session exception handling.""" """Test create_session exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_session_commit_failure_triggers_rollback(self, async_test_db, async_test_user): async def test_create_session_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles commit failures with rollback.""" """Test create_session handles commit failures with rollback."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise OperationalError("Commit failed", {}, Exception()) raise OperationalError("Commit failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate( session_data = SessionCreate(
user_id=async_test_user.id, user_id=async_test_user.id,
refresh_token_jti=str(uuid4()), refresh_token_jti=str(uuid4()),
device_name="Test Device", device_name="Test Device",
ip_address="127.0.0.1", ip_address="127.0.0.1",
user_agent="Test Agent", user_agent="Test Agent",
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
with pytest.raises(ValueError, match="Failed to create session"): with pytest.raises(ValueError, match="Failed to create session"):
@@ -97,24 +108,29 @@ class TestSessionCRUDCreateSessionFailures:
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_session_unexpected_error_triggers_rollback(self, async_test_db, async_test_user): async def test_create_session_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles unexpected errors.""" """Test create_session handles unexpected errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise RuntimeError("Unexpected error") raise RuntimeError("Unexpected error")
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate( session_data = SessionCreate(
user_id=async_test_user.id, user_id=async_test_user.id,
refresh_token_jti=str(uuid4()), refresh_token_jti=str(uuid4()),
device_name="Test Device", device_name="Test Device",
ip_address="127.0.0.1", ip_address="127.0.0.1",
user_agent="Test Agent", user_agent="Test Agent",
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
with pytest.raises(ValueError, match="Failed to create session"): with pytest.raises(ValueError, match="Failed to create session"):
@@ -127,9 +143,11 @@ class TestSessionCRUDDeactivateFailures:
"""Test deactivate exception handling.""" """Test deactivate exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deactivate_commit_failure_triggers_rollback(self, async_test_db, async_test_user): async def test_deactivate_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate handles commit failures.""" """Test deactivate handles commit failures."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create a session first # Create a session first
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -140,8 +158,8 @@ class TestSessionCRUDDeactivateFailures:
ip_address="127.0.0.1", ip_address="127.0.0.1",
user_agent="Test Agent", user_agent="Test Agent",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -150,13 +168,18 @@ class TestSessionCRUDDeactivateFailures:
# Test deactivate failure # Test deactivate failure
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise OperationalError("Deactivate failed", {}, Exception()) raise OperationalError("Deactivate failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.deactivate(session, session_id=str(session_id)) await session_crud.deactivate(
session, session_id=str(session_id)
)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@@ -165,20 +188,24 @@ class TestSessionCRUDDeactivateAllFailures:
"""Test deactivate_all_user_sessions exception handling.""" """Test deactivate_all_user_sessions exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deactivate_all_commit_failure_triggers_rollback(self, async_test_db, async_test_user): async def test_deactivate_all_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate_all handles commit failures.""" """Test deactivate_all handles commit failures."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise OperationalError("Bulk deactivate failed", {}, Exception()) raise OperationalError("Bulk deactivate failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.deactivate_all_user_sessions( await session_crud.deactivate_all_user_sessions(
session, session, user_id=str(async_test_user.id)
user_id=str(async_test_user.id)
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@@ -188,9 +215,11 @@ class TestSessionCRUDUpdateLastUsedFailures:
"""Test update_last_used exception handling.""" """Test update_last_used exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_last_used_commit_failure_triggers_rollback(self, async_test_db, async_test_user): async def test_update_last_used_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_last_used handles commit failures.""" """Test update_last_used handles commit failures."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create a session # Create a session
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -201,8 +230,8 @@ class TestSessionCRUDUpdateLastUsedFailures:
ip_address="127.0.0.1", ip_address="127.0.0.1",
user_agent="Test Agent", user_agent="Test Agent",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1) last_used_at=datetime.now(UTC) - timedelta(hours=1),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -211,15 +240,19 @@ class TestSessionCRUDUpdateLastUsedFailures:
# Test update failure # Test update failure
async with SessionLocal() as session: async with SessionLocal() as session:
from sqlalchemy import select from sqlalchemy import select
from app.models.user_session import UserSession as US from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id)) result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one() sess = result.scalar_one()
async def mock_commit(): async def mock_commit():
raise OperationalError("Update failed", {}, Exception()) raise OperationalError("Update failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.update_last_used(session, session=sess) await session_crud.update_last_used(session, session=sess)
@@ -230,9 +263,11 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
"""Test update_refresh_token exception handling.""" """Test update_refresh_token exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_refresh_token_commit_failure_triggers_rollback(self, async_test_db, async_test_user): async def test_update_refresh_token_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_refresh_token handles commit failures.""" """Test update_refresh_token handles commit failures."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create a session # Create a session
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -243,8 +278,8 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
ip_address="127.0.0.1", ip_address="127.0.0.1",
user_agent="Test Agent", user_agent="Test Agent",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(user_session) session.add(user_session)
await session.commit() await session.commit()
@@ -253,21 +288,25 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
# Test update failure # Test update failure
async with SessionLocal() as session: async with SessionLocal() as session:
from sqlalchemy import select from sqlalchemy import select
from app.models.user_session import UserSession as US from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id)) result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one() sess = result.scalar_one()
async def mock_commit(): async def mock_commit():
raise OperationalError("Token update failed", {}, Exception()) raise OperationalError("Token update failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.update_refresh_token( await session_crud.update_refresh_token(
session, session,
session=sess, session=sess,
new_jti=str(uuid4()), new_jti=str(uuid4()),
new_expires_at=datetime.now(timezone.utc) + timedelta(days=14) new_expires_at=datetime.now(UTC) + timedelta(days=14),
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@@ -277,16 +316,21 @@ class TestSessionCRUDCleanupExpiredFailures:
"""Test cleanup_expired exception handling.""" """Test cleanup_expired exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_commit_failure_triggers_rollback(self, async_test_db): async def test_cleanup_expired_commit_failure_triggers_rollback(
self, async_test_db
):
"""Test cleanup_expired handles commit failures.""" """Test cleanup_expired handles commit failures."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise OperationalError("Cleanup failed", {}, Exception()) raise OperationalError("Cleanup failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.cleanup_expired(session, keep_days=30) await session_crud.cleanup_expired(session, keep_days=30)
@@ -297,20 +341,24 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
"""Test cleanup_expired_for_user exception handling.""" """Test cleanup_expired_for_user exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(self, async_test_db, async_test_user): async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test cleanup_expired_for_user handles commit failures.""" """Test cleanup_expired_for_user handles commit failures."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_commit(): async def mock_commit():
raise OperationalError("User cleanup failed", {}, Exception()) raise OperationalError("User cleanup failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.cleanup_expired_for_user( await session_crud.cleanup_expired_for_user(
session, session, user_id=str(async_test_user.id)
user_id=str(async_test_user.id)
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@@ -320,17 +368,19 @@ class TestSessionCRUDGetUserSessionCountFailures:
"""Test get_user_session_count exception handling.""" """Test get_user_session_count exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_session_count_database_error(self, async_test_db, async_test_user): async def test_get_user_session_count_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_session_count handles database errors.""" """Test get_user_session_count handles database errors."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
async def mock_execute(*args, **kwargs): async def mock_execute(*args, **kwargs):
raise OperationalError("Count query failed", {}, Exception()) raise OperationalError("Count query failed", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_user_session_count( await session_crud.get_user_session_count(
session, session, user_id=str(async_test_user.id)
user_id=str(async_test_user.id)
) )

View File

@@ -2,12 +2,10 @@
""" """
Comprehensive tests for async user CRUD operations. Comprehensive tests for async user CRUD operations.
""" """
import pytest import pytest
from datetime import datetime, timezone
from uuid import uuid4
from app.crud.user import user as user_crud from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
@@ -17,7 +15,7 @@ class TestGetByEmail:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_by_email_success(self, async_test_db, async_test_user): async def test_get_by_email_success(self, async_test_db, async_test_user):
"""Test getting user by email.""" """Test getting user by email."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email=async_test_user.email) result = await user_crud.get_by_email(session, email=async_test_user.email)
@@ -28,10 +26,12 @@ class TestGetByEmail:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_by_email_not_found(self, async_test_db): async def test_get_by_email_not_found(self, async_test_db):
"""Test getting non-existent email returns None.""" """Test getting non-existent email returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email="nonexistent@example.com") result = await user_crud.get_by_email(
session, email="nonexistent@example.com"
)
assert result is None assert result is None
@@ -41,7 +41,7 @@ class TestCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_user_success(self, async_test_db): async def test_create_user_success(self, async_test_db):
"""Test successfully creating a user_crud.""" """Test successfully creating a user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_data = UserCreate( user_data = UserCreate(
@@ -49,7 +49,7 @@ class TestCreate:
password="SecurePass123!", password="SecurePass123!",
first_name="New", first_name="New",
last_name="User", last_name="User",
phone_number="+1234567890" phone_number="+1234567890",
) )
result = await user_crud.create(session, obj_in=user_data) result = await user_crud.create(session, obj_in=user_data)
@@ -65,7 +65,7 @@ class TestCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_superuser_success(self, async_test_db): async def test_create_superuser_success(self, async_test_db):
"""Test creating a superuser.""" """Test creating a superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_data = UserCreate( user_data = UserCreate(
@@ -73,7 +73,7 @@ class TestCreate:
password="SuperPass123!", password="SuperPass123!",
first_name="Super", first_name="Super",
last_name="User", last_name="User",
is_superuser=True is_superuser=True,
) )
result = await user_crud.create(session, obj_in=user_data) result = await user_crud.create(session, obj_in=user_data)
@@ -83,14 +83,14 @@ class TestCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user): async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
"""Test creating user with duplicate email raises ValueError.""" """Test creating user with duplicate email raises ValueError."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_data = UserCreate( user_data = UserCreate(
email=async_test_user.email, # Duplicate email email=async_test_user.email, # Duplicate email
password="AnotherPass123!", password="AnotherPass123!",
first_name="Duplicate", first_name="Duplicate",
last_name="User" last_name="User",
) )
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
@@ -105,16 +105,14 @@ class TestUpdate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_user_basic_fields(self, async_test_db, async_test_user): async def test_update_user_basic_fields(self, async_test_db, async_test_user):
"""Test updating basic user fields.""" """Test updating basic user fields."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get fresh copy of user # Get fresh copy of user
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
update_data = UserUpdate( update_data = UserUpdate(
first_name="Updated", first_name="Updated", last_name="Name", phone_number="+9876543210"
last_name="Name",
phone_number="+9876543210"
) )
result = await user_crud.update(session, db_obj=user, obj_in=update_data) result = await user_crud.update(session, db_obj=user, obj_in=update_data)
@@ -125,7 +123,7 @@ class TestUpdate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_user_password(self, async_test_db): async def test_update_user_password(self, async_test_db):
"""Test updating user password.""" """Test updating user password."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create a fresh user for this test # Create a fresh user for this test
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -133,7 +131,7 @@ class TestUpdate:
email="passwordtest@example.com", email="passwordtest@example.com",
password="OldPassword123!", password="OldPassword123!",
first_name="Pass", first_name="Pass",
last_name="Test" last_name="Test",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
@@ -149,12 +147,14 @@ class TestUpdate:
await session.refresh(result) await session.refresh(result)
assert result.password_hash != old_password_hash assert result.password_hash != old_password_hash
assert result.password_hash is not None assert result.password_hash is not None
assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed assert (
"NewDifferentPassword123!" not in result.password_hash
) # Should be hashed
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_user_with_dict(self, async_test_db, async_test_user): async def test_update_user_with_dict(self, async_test_db, async_test_user):
"""Test updating user with dictionary.""" """Test updating user with dictionary."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
@@ -171,13 +171,11 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user): async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
"""Test basic pagination.""" """Test basic pagination."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, total = await user_crud.get_multi_with_total(
session, session, skip=0, limit=10
skip=0,
limit=10
) )
assert total >= 1 assert total >= 1
assert len(users) >= 1 assert len(users) >= 1
@@ -186,7 +184,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_sorting_asc(self, async_test_db): async def test_get_multi_with_total_sorting_asc(self, async_test_db):
"""Test sorting in ascending order.""" """Test sorting in ascending order."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users # Create multiple users
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -195,17 +193,13 @@ class TestGetMultiWithTotal:
email=f"sort{i}@example.com", email=f"sort{i}@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name=f"User{i}", first_name=f"User{i}",
last_name="Test" last_name="Test",
) )
await user_crud.create(session, obj_in=user_data) await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, _total = await user_crud.get_multi_with_total(
session, session, skip=0, limit=10, sort_by="email", sort_order="asc"
skip=0,
limit=10,
sort_by="email",
sort_order="asc"
) )
# Check if sorted (at least the test users) # Check if sorted (at least the test users)
@@ -216,7 +210,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_sorting_desc(self, async_test_db): async def test_get_multi_with_total_sorting_desc(self, async_test_db):
"""Test sorting in descending order.""" """Test sorting in descending order."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users # Create multiple users
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -225,17 +219,13 @@ class TestGetMultiWithTotal:
email=f"desc{i}@example.com", email=f"desc{i}@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name=f"User{i}", first_name=f"User{i}",
last_name="Test" last_name="Test",
) )
await user_crud.create(session, obj_in=user_data) await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, _total = await user_crud.get_multi_with_total(
session, session, skip=0, limit=10, sort_by="email", sort_order="desc"
skip=0,
limit=10,
sort_by="email",
sort_order="desc"
) )
# Check if sorted descending (at least the test users) # Check if sorted descending (at least the test users)
@@ -246,7 +236,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_filtering(self, async_test_db): async def test_get_multi_with_total_filtering(self, async_test_db):
"""Test filtering by field.""" """Test filtering by field."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create active and inactive users # Create active and inactive users
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -254,7 +244,7 @@ class TestGetMultiWithTotal:
email="active@example.com", email="active@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name="Active", first_name="Active",
last_name="User" last_name="User",
) )
await user_crud.create(session, obj_in=active_user) await user_crud.create(session, obj_in=active_user)
@@ -262,23 +252,18 @@ class TestGetMultiWithTotal:
email="inactive@example.com", email="inactive@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name="Inactive", first_name="Inactive",
last_name="User" last_name="User",
) )
created_inactive = await user_crud.create(session, obj_in=inactive_user) created_inactive = await user_crud.create(session, obj_in=inactive_user)
# Deactivate the user # Deactivate the user
await user_crud.update( await user_crud.update(
session, session, db_obj=created_inactive, obj_in={"is_active": False}
db_obj=created_inactive,
obj_in={"is_active": False}
) )
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, _total = await user_crud.get_multi_with_total(
session, session, skip=0, limit=100, filters={"is_active": True}
skip=0,
limit=100,
filters={"is_active": True}
) )
# All returned users should be active # All returned users should be active
@@ -287,7 +272,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_search(self, async_test_db): async def test_get_multi_with_total_search(self, async_test_db):
"""Test search functionality.""" """Test search functionality."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create user with unique name # Create user with unique name
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -295,16 +280,13 @@ class TestGetMultiWithTotal:
email="searchable@example.com", email="searchable@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name="Searchable", first_name="Searchable",
last_name="UserName" last_name="UserName",
) )
await user_crud.create(session, obj_in=user_data) await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, total = await user_crud.get_multi_with_total(
session, session, skip=0, limit=100, search="Searchable"
skip=0,
limit=100,
search="Searchable"
) )
assert total >= 1 assert total >= 1
@@ -313,7 +295,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_pagination(self, async_test_db): async def test_get_multi_with_total_pagination(self, async_test_db):
"""Test pagination with skip and limit.""" """Test pagination with skip and limit."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users # Create multiple users
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -322,23 +304,19 @@ class TestGetMultiWithTotal:
email=f"page{i}@example.com", email=f"page{i}@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name=f"Page{i}", first_name=f"Page{i}",
last_name="User" last_name="User",
) )
await user_crud.create(session, obj_in=user_data) await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get first page # Get first page
users_page1, total = await user_crud.get_multi_with_total( users_page1, total = await user_crud.get_multi_with_total(
session, session, skip=0, limit=2
skip=0,
limit=2
) )
# Get second page # Get second page
users_page2, total2 = await user_crud.get_multi_with_total( users_page2, total2 = await user_crud.get_multi_with_total(
session, session, skip=2, limit=2
skip=2,
limit=2
) )
# Total should be same # Total should be same
@@ -349,7 +327,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db): async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
"""Test validation fails for negative skip.""" """Test validation fails for negative skip."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
@@ -360,7 +338,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db): async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
"""Test validation fails for negative limit.""" """Test validation fails for negative limit."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
@@ -371,7 +349,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_validation_max_limit(self, async_test_db): async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
"""Test validation fails for limit > 1000.""" """Test validation fails for limit > 1000."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
@@ -386,7 +364,7 @@ class TestBulkUpdateStatus:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_update_status_success(self, async_test_db): async def test_bulk_update_status_success(self, async_test_db):
"""Test bulk updating user status.""" """Test bulk updating user status."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users # Create multiple users
user_ids = [] user_ids = []
@@ -396,7 +374,7 @@ class TestBulkUpdateStatus:
email=f"bulk{i}@example.com", email=f"bulk{i}@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name=f"Bulk{i}", first_name=f"Bulk{i}",
last_name="User" last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id) user_ids.append(user.id)
@@ -404,9 +382,7 @@ class TestBulkUpdateStatus:
# Bulk deactivate # Bulk deactivate
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status( count = await user_crud.bulk_update_status(
session, session, user_ids=user_ids, is_active=False
user_ids=user_ids,
is_active=False
) )
assert count == 3 assert count == 3
@@ -419,20 +395,18 @@ class TestBulkUpdateStatus:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_update_status_empty_list(self, async_test_db): async def test_bulk_update_status_empty_list(self, async_test_db):
"""Test bulk update with empty list returns 0.""" """Test bulk update with empty list returns 0."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status( count = await user_crud.bulk_update_status(
session, session, user_ids=[], is_active=False
user_ids=[],
is_active=False
) )
assert count == 0 assert count == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_update_status_reactivate(self, async_test_db): async def test_bulk_update_status_reactivate(self, async_test_db):
"""Test bulk reactivating users.""" """Test bulk reactivating users."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create inactive user # Create inactive user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -440,7 +414,7 @@ class TestBulkUpdateStatus:
email="reactivate@example.com", email="reactivate@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name="Reactivate", first_name="Reactivate",
last_name="User" last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
# Deactivate # Deactivate
@@ -450,9 +424,7 @@ class TestBulkUpdateStatus:
# Reactivate # Reactivate
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status( count = await user_crud.bulk_update_status(
session, session, user_ids=[user_id], is_active=True
user_ids=[user_id],
is_active=True
) )
assert count == 1 assert count == 1
@@ -468,7 +440,7 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_soft_delete_success(self, async_test_db): async def test_bulk_soft_delete_success(self, async_test_db):
"""Test bulk soft deleting users.""" """Test bulk soft deleting users."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users # Create multiple users
user_ids = [] user_ids = []
@@ -478,17 +450,14 @@ class TestBulkSoftDelete:
email=f"delete{i}@example.com", email=f"delete{i}@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name=f"Delete{i}", first_name=f"Delete{i}",
last_name="User" last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id) user_ids.append(user.id)
# Bulk delete # Bulk delete
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete( count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
session,
user_ids=user_ids
)
assert count == 3 assert count == 3
# Verify all are soft deleted # Verify all are soft deleted
@@ -501,7 +470,7 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_soft_delete_with_exclusion(self, async_test_db): async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
"""Test bulk soft delete with excluded user_crud.""" """Test bulk soft delete with excluded user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users # Create multiple users
user_ids = [] user_ids = []
@@ -511,7 +480,7 @@ class TestBulkSoftDelete:
email=f"exclude{i}@example.com", email=f"exclude{i}@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name=f"Exclude{i}", first_name=f"Exclude{i}",
last_name="User" last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id) user_ids.append(user.id)
@@ -520,9 +489,7 @@ class TestBulkSoftDelete:
exclude_id = user_ids[0] exclude_id = user_ids[0]
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete( count = await user_crud.bulk_soft_delete(
session, session, user_ids=user_ids, exclude_user_id=exclude_id
user_ids=user_ids,
exclude_user_id=exclude_id
) )
assert count == 2 # Only 2 deleted assert count == 2 # Only 2 deleted
@@ -534,19 +501,16 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_soft_delete_empty_list(self, async_test_db): async def test_bulk_soft_delete_empty_list(self, async_test_db):
"""Test bulk delete with empty list returns 0.""" """Test bulk delete with empty list returns 0."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete( count = await user_crud.bulk_soft_delete(session, user_ids=[])
session,
user_ids=[]
)
assert count == 0 assert count == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_soft_delete_all_excluded(self, async_test_db): async def test_bulk_soft_delete_all_excluded(self, async_test_db):
"""Test bulk delete where all users are excluded.""" """Test bulk delete where all users are excluded."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create user # Create user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -554,7 +518,7 @@ class TestBulkSoftDelete:
email="onlyuser@example.com", email="onlyuser@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name="Only", first_name="Only",
last_name="User" last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
@@ -562,16 +526,14 @@ class TestBulkSoftDelete:
# Try to delete but exclude # Try to delete but exclude
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete( count = await user_crud.bulk_soft_delete(
session, session, user_ids=[user_id], exclude_user_id=user_id
user_ids=[user_id],
exclude_user_id=user_id
) )
assert count == 0 assert count == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_soft_delete_already_deleted(self, async_test_db): async def test_bulk_soft_delete_already_deleted(self, async_test_db):
"""Test bulk delete doesn't re-delete already deleted users.""" """Test bulk delete doesn't re-delete already deleted users."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create and delete user # Create and delete user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -579,7 +541,7 @@ class TestBulkSoftDelete:
email="predeleted@example.com", email="predeleted@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name="PreDeleted", first_name="PreDeleted",
last_name="User" last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
@@ -589,10 +551,7 @@ class TestBulkSoftDelete:
# Try to delete again # Try to delete again
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete( count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
session,
user_ids=[user_id]
)
assert count == 0 # Already deleted assert count == 0 # Already deleted
@@ -602,7 +561,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_active_true(self, async_test_db, async_test_user): async def test_is_active_true(self, async_test_db, async_test_user):
"""Test is_active returns True for active user_crud.""" """Test is_active returns True for active user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
@@ -611,14 +570,14 @@ class TestUtilityMethods:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_active_false(self, async_test_db): async def test_is_active_false(self, async_test_db):
"""Test is_active returns False for inactive user_crud.""" """Test is_active returns False for inactive user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user_data = UserCreate( user_data = UserCreate(
email="inactive2@example.com", email="inactive2@example.com",
password="SecurePass123!", password="SecurePass123!",
first_name="Inactive", first_name="Inactive",
last_name="User" last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_crud.create(session, obj_in=user_data)
await user_crud.update(session, db_obj=user, obj_in={"is_active": False}) await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
@@ -628,7 +587,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_superuser_true(self, async_test_db, async_test_superuser): async def test_is_superuser_true(self, async_test_db, async_test_superuser):
"""Test is_superuser returns True for superuser.""" """Test is_superuser returns True for superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_superuser.id)) user = await user_crud.get(session, id=str(async_test_superuser.id))
@@ -637,7 +596,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_superuser_false(self, async_test_db, async_test_user): async def test_is_superuser_false(self, async_test_db, async_test_user):
"""Test is_superuser returns False for regular user_crud.""" """Test is_superuser returns False for regular user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_crud.get(session, id=str(async_test_user.id))
@@ -654,42 +613,52 @@ class TestUserExceptionHandlers:
async def test_get_by_email_database_error(self, async_test_db): async def test_get_by_email_database_error(self, async_test_db):
"""Test get_by_email handles database errors (covers lines 30-32).""" """Test get_by_email handles database errors (covers lines 30-32)."""
from unittest.mock import patch from unittest.mock import patch
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("Database query failed")): with patch.object(
session, "execute", side_effect=Exception("Database query failed")
):
with pytest.raises(Exception, match="Database query failed"): with pytest.raises(Exception, match="Database query failed"):
await user_crud.get_by_email(session, email="test@example.com") await user_crud.get_by_email(session, email="test@example.com")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_update_status_database_error(self, async_test_db, async_test_user): async def test_bulk_update_status_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_update_status handles database errors (covers lines 205-208).""" """Test bulk_update_status handles database errors (covers lines 205-208)."""
from unittest.mock import patch, AsyncMock from unittest.mock import AsyncMock, patch
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock execute to fail # Mock execute to fail
with patch.object(session, 'execute', side_effect=Exception("Bulk update failed")): with patch.object(
with patch.object(session, 'rollback', new_callable=AsyncMock): session, "execute", side_effect=Exception("Bulk update failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk update failed"): with pytest.raises(Exception, match="Bulk update failed"):
await user_crud.bulk_update_status( await user_crud.bulk_update_status(
session, session, user_ids=[async_test_user.id], is_active=False
user_ids=[async_test_user.id],
is_active=False
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_soft_delete_database_error(self, async_test_db, async_test_user): async def test_bulk_soft_delete_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_soft_delete handles database errors (covers lines 257-260).""" """Test bulk_soft_delete handles database errors (covers lines 257-260)."""
from unittest.mock import patch, AsyncMock from unittest.mock import AsyncMock, patch
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Mock execute to fail # Mock execute to fail
with patch.object(session, 'execute', side_effect=Exception("Bulk delete failed")): with patch.object(
with patch.object(session, 'rollback', new_callable=AsyncMock): session, "execute", side_effect=Exception("Bulk delete failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk delete failed"): with pytest.raises(Exception, match="Bulk delete failed"):
await user_crud.bulk_soft_delete( await user_crud.bulk_soft_delete(
session, session, user_ids=[async_test_user.id]
user_ids=[async_test_user.id]
) )

View File

@@ -1,8 +1,10 @@
# tests/models/test_user.py # tests/models/test_user.py
import uuid import uuid
import pytest
from datetime import datetime from datetime import datetime
import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from app.models.user import User from app.models.user import User
@@ -166,7 +168,6 @@ def test_user_required_fields(db_session):
db_session.rollback() db_session.rollback()
def test_user_defaults(db_session): def test_user_defaults(db_session):
"""Test that default values are correctly set.""" """Test that default values are correctly set."""
# Arrange - Create a minimal user with only required fields # Arrange - Create a minimal user with only required fields
@@ -210,22 +211,13 @@ def test_user_with_complex_json_preferences(db_session):
"""Test storing and retrieving complex JSON preferences.""" """Test storing and retrieving complex JSON preferences."""
# Arrange - Create a user with nested JSON preferences # Arrange - Create a user with nested JSON preferences
complex_preferences = { complex_preferences = {
"theme": { "theme": {"mode": "dark", "colors": {"primary": "#333", "secondary": "#666"}},
"mode": "dark",
"colors": {
"primary": "#333",
"secondary": "#666"
}
},
"notifications": { "notifications": {
"email": True, "email": True,
"sms": False, "sms": False,
"push": { "push": {"enabled": True, "quiet_hours": [22, 7]},
"enabled": True,
"quiet_hours": [22, 7]
}
}, },
"tags": ["important", "family", "events"] "tags": ["important", "family", "events"],
} }
user = User( user = User(
@@ -234,13 +226,15 @@ def test_user_with_complex_json_preferences(db_session):
password_hash="hashedpassword", password_hash="hashedpassword",
first_name="Complex", first_name="Complex",
last_name="JSON", last_name="JSON",
preferences=complex_preferences preferences=complex_preferences,
) )
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
# Act - Retrieve the user # Act - Retrieve the user
retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first() retrieved_user = (
db_session.query(User).filter_by(email="complex@example.com").first()
)
# Assert - The complex JSON should be preserved # Assert - The complex JSON should be preserved
assert retrieved_user.preferences == complex_preferences assert retrieved_user.preferences == complex_preferences

View File

@@ -5,6 +5,7 @@ Covers Pydantic validators for:
- Slug validation (lines 26, 28, 30, 32, 62-70) - Slug validation (lines 26, 28, 30, 32, 62-70)
- Name validation (lines 40, 77) - Name validation (lines 40, 77)
""" """
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
@@ -20,19 +21,13 @@ class TestOrganizationBaseValidators:
def test_valid_organization_base(self): def test_valid_organization_base(self):
"""Test that valid data passes validation.""" """Test that valid data passes validation."""
org = OrganizationBase( org = OrganizationBase(name="Test Organization", slug="test-org")
name="Test Organization",
slug="test-org"
)
assert org.name == "Test Organization" assert org.name == "Test Organization"
assert org.slug == "test-org" assert org.slug == "test-org"
def test_slug_none_returns_none(self): def test_slug_none_returns_none(self):
"""Test that None slug is allowed (covers line 26).""" """Test that None slug is allowed (covers line 26)."""
org = OrganizationBase( org = OrganizationBase(name="Test Organization", slug=None)
name="Test Organization",
slug=None
)
assert org.slug is None assert org.slug is None
def test_slug_invalid_characters_rejected(self): def test_slug_invalid_characters_rejected(self):
@@ -40,57 +35,46 @@ class TestOrganizationBaseValidators:
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationBase( OrganizationBase(
name="Test Organization", name="Test Organization",
slug="Test_Org!" # Uppercase and special chars slug="Test_Org!", # Uppercase and special chars
) )
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors) assert any(
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
)
def test_slug_starts_with_hyphen_rejected(self): def test_slug_starts_with_hyphen_rejected(self):
"""Test slug starting with hyphen is rejected (covers line 30).""" """Test slug starting with hyphen is rejected (covers line 30)."""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationBase( OrganizationBase(name="Test Organization", slug="-test-org")
name="Test Organization",
slug="-test-org"
)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors) assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_slug_ends_with_hyphen_rejected(self): def test_slug_ends_with_hyphen_rejected(self):
"""Test slug ending with hyphen is rejected (covers line 30).""" """Test slug ending with hyphen is rejected (covers line 30)."""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationBase( OrganizationBase(name="Test Organization", slug="test-org-")
name="Test Organization",
slug="test-org-"
)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors) assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_slug_consecutive_hyphens_rejected(self): def test_slug_consecutive_hyphens_rejected(self):
"""Test slug with consecutive hyphens is rejected (covers line 32).""" """Test slug with consecutive hyphens is rejected (covers line 32)."""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationBase( OrganizationBase(name="Test Organization", slug="test--org")
name="Test Organization",
slug="test--org"
)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors) assert any(
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
)
def test_name_whitespace_only_rejected(self): def test_name_whitespace_only_rejected(self):
"""Test whitespace-only name is rejected (covers line 40).""" """Test whitespace-only name is rejected (covers line 40)."""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationBase( OrganizationBase(name=" ", slug="test-org")
name=" ",
slug="test-org"
)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("name cannot be empty" in str(e['msg']) for e in errors) assert any("name cannot be empty" in str(e["msg"]) for e in errors)
def test_name_trimmed(self): def test_name_trimmed(self):
"""Test that name is trimmed.""" """Test that name is trimmed."""
org = OrganizationBase( org = OrganizationBase(name=" Test Organization ", slug="test-org")
name=" Test Organization ",
slug="test-org"
)
assert org.name == "Test Organization" assert org.name == "Test Organization"
@@ -99,22 +83,18 @@ class TestOrganizationCreateValidators:
def test_valid_organization_create(self): def test_valid_organization_create(self):
"""Test that valid data passes validation.""" """Test that valid data passes validation."""
org = OrganizationCreate( org = OrganizationCreate(name="Test Organization", slug="test-org")
name="Test Organization",
slug="test-org"
)
assert org.name == "Test Organization" assert org.name == "Test Organization"
assert org.slug == "test-org" assert org.slug == "test-org"
def test_slug_validation_inherited(self): def test_slug_validation_inherited(self):
"""Test that slug validation is inherited from base.""" """Test that slug validation is inherited from base."""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationCreate( OrganizationCreate(name="Test", slug="Invalid_Slug!")
name="Test",
slug="Invalid_Slug!"
)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors) assert any(
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
)
class TestOrganizationUpdateValidators: class TestOrganizationUpdateValidators:
@@ -122,10 +102,7 @@ class TestOrganizationUpdateValidators:
def test_valid_organization_update(self): def test_valid_organization_update(self):
"""Test that valid update data passes validation.""" """Test that valid update data passes validation."""
org = OrganizationUpdate( org = OrganizationUpdate(name="Updated Name", slug="updated-slug")
name="Updated Name",
slug="updated-slug"
)
assert org.name == "Updated Name" assert org.name == "Updated Name"
assert org.slug == "updated-slug" assert org.slug == "updated-slug"
@@ -139,35 +116,39 @@ class TestOrganizationUpdateValidators:
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="Test_Org!") OrganizationUpdate(slug="Test_Org!")
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors) assert any(
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
)
def test_update_slug_starts_with_hyphen_rejected(self): def test_update_slug_starts_with_hyphen_rejected(self):
"""Test update slug starting with hyphen is rejected (covers line 66).""" """Test update slug starting with hyphen is rejected (covers line 66)."""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="-test-org") OrganizationUpdate(slug="-test-org")
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors) assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_update_slug_ends_with_hyphen_rejected(self): def test_update_slug_ends_with_hyphen_rejected(self):
"""Test update slug ending with hyphen is rejected (covers line 66).""" """Test update slug ending with hyphen is rejected (covers line 66)."""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="test-org-") OrganizationUpdate(slug="test-org-")
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors) assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_update_slug_consecutive_hyphens_rejected(self): def test_update_slug_consecutive_hyphens_rejected(self):
"""Test update slug with consecutive hyphens is rejected (covers line 68).""" """Test update slug with consecutive hyphens is rejected (covers line 68)."""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="test--org") OrganizationUpdate(slug="test--org")
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors) assert any(
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
)
def test_update_name_whitespace_only_rejected(self): def test_update_name_whitespace_only_rejected(self):
"""Test whitespace-only name in update is rejected (covers line 77).""" """Test whitespace-only name in update is rejected (covers line 77)."""
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(name=" ") OrganizationUpdate(name=" ")
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any("name cannot be empty" in str(e['msg']) for e in errors) assert any("name cannot be empty" in str(e["msg"]) for e in errors)
def test_update_name_none_allowed(self): def test_update_name_none_allowed(self):
"""Test that None name is allowed in update.""" """Test that None name is allowed in update."""

View File

@@ -1,80 +1,177 @@
# tests/schemas/test_user_schemas.py # tests/schemas/test_user_schemas.py
import pytest
import re import re
import pytest
from pydantic import ValidationError from pydantic import ValidationError
from app.schemas.users import UserBase, UserCreate from app.schemas.users import UserBase, UserCreate
class TestPhoneNumberValidation: class TestPhoneNumberValidation:
"""Tests for phone number validation in user schemas""" """Tests for phone number validation in user schemas"""
def test_valid_swiss_numbers(self): def test_valid_swiss_numbers(self):
"""Test valid Swiss phone numbers are accepted""" """Test valid Swiss phone numbers are accepted"""
# International format # International format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567") user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41791234567",
)
assert user.phone_number == "+41791234567" assert user.phone_number == "+41791234567"
# Local format # Local format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567") user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0791234567",
)
assert user.phone_number == "0791234567" assert user.phone_number == "0791234567"
# With formatting characters # With formatting characters
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41 79 123 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="079 123 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41-79-123-45-67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="079-123-45-67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41 (79) 123 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="079 (123) 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
def test_valid_italian_numbers(self): def test_valid_italian_numbers(self):
"""Test valid Italian phone numbers are accepted""" """Test valid Italian phone numbers are accepted"""
# International format # International format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567") user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+393451234567",
)
assert user.phone_number == "+393451234567" assert user.phone_number == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456") user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39345123456",
)
assert user.phone_number == "+39345123456" assert user.phone_number == "+39345123456"
# Local format # Local format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567") user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="03451234567",
)
assert user.phone_number == "03451234567" assert user.phone_number == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789") user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345123456789",
)
assert user.phone_number == "0345123456789" assert user.phone_number == "0345123456789"
# With formatting characters # With formatting characters
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39 345 123 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345 123 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39-345-123-4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345-123-4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39 (345) 123 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567") user = UserBase(
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567" email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345 (123) 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
def test_none_phone_number(self): def test_none_phone_number(self):
"""Test that None is accepted as a valid value (optional phone number)""" """Test that None is accepted as a valid value (optional phone number)"""
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None) user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number=None,
)
assert user.phone_number is None assert user.phone_number is None
def test_invalid_phone_numbers(self): def test_invalid_phone_numbers(self):
@@ -83,17 +180,14 @@ class TestPhoneNumberValidation:
# Too short # Too short
"+12", "+12",
"012", "012",
# Invalid characters # Invalid characters
"+41xyz123456", "+41xyz123456",
"079abc4567", "079abc4567",
"123-abc-7890", "123-abc-7890",
"+1(800)CALL-NOW", "+1(800)CALL-NOW",
# Completely invalid formats # Completely invalid formats
"++4412345678", # Double plus "++4412345678", # Double plus
# Note: "()+41123456" becomes "+41123456" after cleaning, which is valid # Note: "()+41123456" becomes "+41123456" after cleaning, which is valid
# Empty string # Empty string
"", "",
# Spaces only # Spaces only
@@ -102,7 +196,12 @@ class TestPhoneNumberValidation:
for number in invalid_numbers: for number in invalid_numbers:
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number) UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number=number,
)
def test_phone_validation_in_user_create(self): def test_phone_validation_in_user_create(self):
"""Test that phone validation also works in UserCreate schema""" """Test that phone validation also works in UserCreate schema"""
@@ -112,7 +211,7 @@ class TestPhoneNumberValidation:
first_name="Test", first_name="Test",
last_name="User", last_name="User",
password="Password123!", password="Password123!",
phone_number="+41791234567" phone_number="+41791234567",
) )
assert user.phone_number == "+41791234567" assert user.phone_number == "+41791234567"
@@ -123,5 +222,5 @@ class TestPhoneNumberValidation:
first_name="Test", first_name="Test",
last_name="User", last_name="User",
password="Password123!", password="Password123!",
phone_number="invalid-number" phone_number="invalid-number",
) )

View File

@@ -7,12 +7,13 @@ Covers all edge cases in validation functions:
- validate_email_format (line 148) - validate_email_format (line 148)
- validate_slug (lines 170-183) - validate_slug (lines 170-183)
""" """
import pytest import pytest
from app.schemas.validators import ( from app.schemas.validators import (
validate_email_format,
validate_password_strength, validate_password_strength,
validate_phone_number, validate_phone_number,
validate_email_format,
validate_slug, validate_slug,
) )
@@ -108,12 +109,14 @@ class TestPhoneNumberValidator:
validate_phone_number("+123456789012345") # 15 digits after + validate_phone_number("+123456789012345") # 15 digits after +
def test_multiple_plus_symbols_rejected(self): def test_multiple_plus_symbols_rejected(self):
"""Test phone number with multiple + symbols. r"""Test phone number with multiple + symbols.
Note: Line 115 is defensive code - the regex check at line 110 catches this first. Note: Line 115 is defensive code - the regex check at line 110 catches this first.
The regex ^(?:\+[0-9]{8,14}|0[0-9]{8,14})$ only allows + at the start. The regex ^(?:\+[0-9]{8,14}|0[0-9]{8,14})$ only allows + at the start.
""" """
with pytest.raises(ValueError, match="must start with \\+ or 0 followed by 8-14 digits"): with pytest.raises(
ValueError, match="must start with \\+ or 0 followed by 8-14 digits"
):
validate_phone_number("+1234+5678901") validate_phone_number("+1234+5678901")
def test_non_digit_after_prefix_rejected(self): def test_non_digit_after_prefix_rejected(self):

View File

@@ -1,14 +1,18 @@
# tests/services/test_auth_service.py # tests/services/test_auth_service.py
import uuid import uuid
import pytest
import pytest_asyncio
from unittest.mock import patch from unittest.mock import patch
import pytest
from sqlalchemy import select from sqlalchemy import select
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError from app.core.auth import (
TokenInvalidError,
get_password_hash,
verify_password,
)
from app.models.user import User from app.models.user import User
from app.schemas.users import UserCreate, Token from app.schemas.users import Token, UserCreate
from app.services.auth_service import AuthService, AuthenticationError from app.services.auth_service import AuthenticationError, AuthService
class TestAuthServiceAuthentication: class TestAuthServiceAuthentication:
@@ -17,12 +21,14 @@ class TestAuthServiceAuthentication:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_valid_user(self, async_test_db, async_test_user): async def test_authenticate_valid_user(self, async_test_db, async_test_user):
"""Test authenticating a user with valid credentials""" """Test authenticating a user with valid credentials"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user # Set a known password for the mock user
password = "TestPassword123!" password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
user.password_hash = get_password_hash(password) user.password_hash = get_password_hash(password)
await session.commit() await session.commit()
@@ -30,9 +36,7 @@ class TestAuthServiceAuthentication:
# Authenticate with correct credentials # Authenticate with correct credentials
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
auth_user = await AuthService.authenticate_user( auth_user = await AuthService.authenticate_user(
db=session, db=session, email=async_test_user.email, password=password
email=async_test_user.email,
password=password
) )
assert auth_user is not None assert auth_user is not None
@@ -42,26 +46,28 @@ class TestAuthServiceAuthentication:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_nonexistent_user(self, async_test_db): async def test_authenticate_nonexistent_user(self, async_test_db):
"""Test authenticating with an email that doesn't exist""" """Test authenticating with an email that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await AuthService.authenticate_user( user = await AuthService.authenticate_user(
db=session, db=session, email="nonexistent@example.com", password="password"
email="nonexistent@example.com",
password="password"
) )
assert user is None assert user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_with_wrong_password(self, async_test_db, async_test_user): async def test_authenticate_with_wrong_password(
self, async_test_db, async_test_user
):
"""Test authenticating with the wrong password""" """Test authenticating with the wrong password"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user # Set a known password for the mock user
password = "TestPassword123!" password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
user.password_hash = get_password_hash(password) user.password_hash = get_password_hash(password)
await session.commit() await session.commit()
@@ -69,9 +75,7 @@ class TestAuthServiceAuthentication:
# Authenticate with wrong password # Authenticate with wrong password
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
auth_user = await AuthService.authenticate_user( auth_user = await AuthService.authenticate_user(
db=session, db=session, email=async_test_user.email, password="WrongPassword123"
email=async_test_user.email,
password="WrongPassword123"
) )
assert auth_user is None assert auth_user is None
@@ -79,12 +83,14 @@ class TestAuthServiceAuthentication:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_inactive_user(self, async_test_db, async_test_user): async def test_authenticate_inactive_user(self, async_test_db, async_test_user):
"""Test authenticating an inactive user""" """Test authenticating an inactive user"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password and make user inactive # Set a known password and make user inactive
password = "TestPassword123!" password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
user.password_hash = get_password_hash(password) user.password_hash = get_password_hash(password)
user.is_active = False user.is_active = False
@@ -94,9 +100,7 @@ class TestAuthServiceAuthentication:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(AuthenticationError): with pytest.raises(AuthenticationError):
await AuthService.authenticate_user( await AuthService.authenticate_user(
db=session, db=session, email=async_test_user.email, password=password
email=async_test_user.email,
password=password
) )
@@ -106,14 +110,14 @@ class TestAuthServiceUserCreation:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_new_user(self, async_test_db): async def test_create_new_user(self, async_test_db):
"""Test creating a new user""" """Test creating a new user"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
user_data = UserCreate( user_data = UserCreate(
email="newuser@example.com", email="newuser@example.com",
password="TestPassword123!", password="TestPassword123!",
first_name="New", first_name="New",
last_name="User", last_name="User",
phone_number="+1234567890" phone_number="+1234567890",
) )
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -135,15 +139,17 @@ class TestAuthServiceUserCreation:
assert user.is_superuser is False assert user.is_superuser is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_user_with_existing_email(self, async_test_db, async_test_user): async def test_create_user_with_existing_email(
self, async_test_db, async_test_user
):
"""Test creating a user with an email that already exists""" """Test creating a user with an email that already exists"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
user_data = UserCreate( user_data = UserCreate(
email=async_test_user.email, # Use existing email email=async_test_user.email, # Use existing email
password="TestPassword123!", password="TestPassword123!",
first_name="Duplicate", first_name="Duplicate",
last_name="User" last_name="User",
) )
# Should raise AuthenticationError # Should raise AuthenticationError
@@ -169,7 +175,7 @@ class TestAuthServiceTokens:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_tokens(self, async_test_db, async_test_user): async def test_refresh_tokens(self, async_test_db, async_test_user):
"""Test refreshing tokens with a valid refresh token""" """Test refreshing tokens with a valid refresh token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create initial tokens # Create initial tokens
initial_tokens = AuthService.create_tokens(async_test_user) initial_tokens = AuthService.create_tokens(async_test_user)
@@ -177,8 +183,7 @@ class TestAuthServiceTokens:
# Refresh tokens # Refresh tokens
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
new_tokens = await AuthService.refresh_tokens( new_tokens = await AuthService.refresh_tokens(
db=session, db=session, refresh_token=initial_tokens.refresh_token
refresh_token=initial_tokens.refresh_token
) )
# Verify new tokens are different from old ones # Verify new tokens are different from old ones
@@ -188,7 +193,7 @@ class TestAuthServiceTokens:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_tokens_with_invalid_token(self, async_test_db): async def test_refresh_tokens_with_invalid_token(self, async_test_db):
"""Test refreshing tokens with an invalid token""" """Test refreshing tokens with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create an invalid token # Create an invalid token
invalid_token = "invalid.token.string" invalid_token = "invalid.token.string"
@@ -197,14 +202,15 @@ class TestAuthServiceTokens:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(TokenInvalidError): with pytest.raises(TokenInvalidError):
await AuthService.refresh_tokens( await AuthService.refresh_tokens(
db=session, db=session, refresh_token=invalid_token
refresh_token=invalid_token
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_tokens_with_access_token(self, async_test_db, async_test_user): async def test_refresh_tokens_with_access_token(
self, async_test_db, async_test_user
):
"""Test refreshing tokens with an access token instead of refresh token""" """Test refreshing tokens with an access token instead of refresh token"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create tokens # Create tokens
tokens = AuthService.create_tokens(async_test_user) tokens = AuthService.create_tokens(async_test_user)
@@ -213,18 +219,20 @@ class TestAuthServiceTokens:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(TokenInvalidError): with pytest.raises(TokenInvalidError):
await AuthService.refresh_tokens( await AuthService.refresh_tokens(
db=session, db=session, refresh_token=tokens.access_token
refresh_token=tokens.access_token
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_tokens_with_nonexistent_user(self, async_test_db): async def test_refresh_tokens_with_nonexistent_user(self, async_test_db):
"""Test refreshing tokens for a user that doesn't exist in the database""" """Test refreshing tokens for a user that doesn't exist in the database"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create a token for a non-existent user # Create a token for a non-existent user
non_existent_id = str(uuid.uuid4()) non_existent_id = str(uuid.uuid4())
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data: with (
patch("app.core.auth.decode_token"),
patch("app.core.auth.get_token_data") as mock_get_data,
):
# Mock the token data to return a non-existent user ID # Mock the token data to return a non-existent user ID
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id) mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
@@ -232,8 +240,7 @@ class TestAuthServiceTokens:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(TokenInvalidError): with pytest.raises(TokenInvalidError):
await AuthService.refresh_tokens( await AuthService.refresh_tokens(
db=session, db=session, refresh_token="some.refresh.token"
refresh_token="some.refresh.token"
) )
@@ -243,12 +250,14 @@ class TestAuthServicePasswordChange:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_password(self, async_test_db, async_test_user): async def test_change_password(self, async_test_db, async_test_user):
"""Test changing a user's password""" """Test changing a user's password"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user # Set a known password for the mock user
current_password = "CurrentPassword123" current_password = "CurrentPassword123"
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
user.password_hash = get_password_hash(current_password) user.password_hash = get_password_hash(current_password)
await session.commit() await session.commit()
@@ -260,7 +269,7 @@ class TestAuthServicePasswordChange:
db=session, db=session,
user_id=async_test_user.id, user_id=async_test_user.id,
current_password=current_password, current_password=current_password,
new_password=new_password new_password=new_password,
) )
# Verify operation was successful # Verify operation was successful
@@ -268,7 +277,9 @@ class TestAuthServicePasswordChange:
# Verify password was changed # Verify password was changed
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
updated_user = result.scalar_one_or_none() updated_user = result.scalar_one_or_none()
# Verify old password no longer works # Verify old password no longer works
@@ -278,14 +289,18 @@ class TestAuthServicePasswordChange:
assert verify_password(new_password, updated_user.password_hash) assert verify_password(new_password, updated_user.password_hash)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_password_wrong_current_password(self, async_test_db, async_test_user): async def test_change_password_wrong_current_password(
self, async_test_db, async_test_user
):
"""Test changing password with incorrect current password""" """Test changing password with incorrect current password"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user # Set a known password for the mock user
current_password = "CurrentPassword123" current_password = "CurrentPassword123"
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
user.password_hash = get_password_hash(current_password) user.password_hash = get_password_hash(current_password)
await session.commit() await session.commit()
@@ -298,19 +313,21 @@ class TestAuthServicePasswordChange:
db=session, db=session,
user_id=async_test_user.id, user_id=async_test_user.id,
current_password=wrong_password, current_password=wrong_password,
new_password="NewPassword456" new_password="NewPassword456",
) )
# Verify password was not changed # Verify password was not changed
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id)) result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
assert verify_password(current_password, user.password_hash) assert verify_password(current_password, user.password_hash)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_password_nonexistent_user(self, async_test_db): async def test_change_password_nonexistent_user(self, async_test_db):
"""Test changing password for a user that doesn't exist""" """Test changing password for a user that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
non_existent_id = uuid.uuid4() non_existent_id = uuid.uuid4()
@@ -320,5 +337,5 @@ class TestAuthServicePasswordChange:
db=session, db=session,
user_id=non_existent_id, user_id=non_existent_id,
current_password="CurrentPassword123", current_password="CurrentPassword123",
new_password="NewPassword456" new_password="NewPassword456",
) )

View File

@@ -2,13 +2,15 @@
""" """
Tests for email service functionality. Tests for email service functionality.
""" """
from unittest.mock import AsyncMock
import pytest import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from app.services.email_service import ( from app.services.email_service import (
EmailService,
ConsoleEmailBackend, ConsoleEmailBackend,
SMTPEmailBackend EmailService,
SMTPEmailBackend,
) )
@@ -24,7 +26,7 @@ class TestConsoleEmailBackend:
to=["user@example.com"], to=["user@example.com"],
subject="Test Subject", subject="Test Subject",
html_content="<p>Test HTML</p>", html_content="<p>Test HTML</p>",
text_content="Test Text" text_content="Test Text",
) )
assert result is True assert result is True
@@ -37,7 +39,7 @@ class TestConsoleEmailBackend:
result = await backend.send_email( result = await backend.send_email(
to=["user@example.com"], to=["user@example.com"],
subject="Test Subject", subject="Test Subject",
html_content="<p>Test HTML</p>" html_content="<p>Test HTML</p>",
) )
assert result is True assert result is True
@@ -50,7 +52,7 @@ class TestConsoleEmailBackend:
result = await backend.send_email( result = await backend.send_email(
to=["user1@example.com", "user2@example.com"], to=["user1@example.com", "user2@example.com"],
subject="Test Subject", subject="Test Subject",
html_content="<p>Test HTML</p>" html_content="<p>Test HTML</p>",
) )
assert result is True assert result is True
@@ -66,7 +68,7 @@ class TestSMTPEmailBackend:
host="smtp.example.com", host="smtp.example.com",
port=587, port=587,
username="test@example.com", username="test@example.com",
password="password" password="password",
) )
assert backend.host == "smtp.example.com" assert backend.host == "smtp.example.com"
@@ -81,14 +83,14 @@ class TestSMTPEmailBackend:
host="smtp.example.com", host="smtp.example.com",
port=587, port=587,
username="test@example.com", username="test@example.com",
password="password" password="password",
) )
# Should fall back to console backend since SMTP is not implemented # Should fall back to console backend since SMTP is not implemented
result = await backend.send_email( result = await backend.send_email(
to=["user@example.com"], to=["user@example.com"],
subject="Test Subject", subject="Test Subject",
html_content="<p>Test HTML</p>" html_content="<p>Test HTML</p>",
) )
assert result is True assert result is True
@@ -114,9 +116,7 @@ class TestEmailService:
service = EmailService() service = EmailService()
result = await service.send_password_reset_email( result = await service.send_password_reset_email(
to_email="user@example.com", to_email="user@example.com", reset_token="test_token_123", user_name="John"
reset_token="test_token_123",
user_name="John"
) )
assert result is True assert result is True
@@ -127,8 +127,7 @@ class TestEmailService:
service = EmailService() service = EmailService()
result = await service.send_password_reset_email( result = await service.send_password_reset_email(
to_email="user@example.com", to_email="user@example.com", reset_token="test_token_123"
reset_token="test_token_123"
) )
assert result is True assert result is True
@@ -142,8 +141,7 @@ class TestEmailService:
token = "test_reset_token_xyz" token = "test_reset_token_xyz"
await service.send_password_reset_email( await service.send_password_reset_email(
to_email="user@example.com", to_email="user@example.com", reset_token=token
reset_token=token
) )
# Verify send_email was called # Verify send_email was called
@@ -151,7 +149,7 @@ class TestEmailService:
call_args = backend_mock.send_email.call_args call_args = backend_mock.send_email.call_args
# Check that token is in the HTML content # Check that token is in the HTML content
html_content = call_args.kwargs['html_content'] html_content = call_args.kwargs["html_content"]
assert token in html_content assert token in html_content
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -162,8 +160,7 @@ class TestEmailService:
service = EmailService(backend=backend_mock) service = EmailService(backend=backend_mock)
result = await service.send_password_reset_email( result = await service.send_password_reset_email(
to_email="user@example.com", to_email="user@example.com", reset_token="test_token"
reset_token="test_token"
) )
assert result is False assert result is False
@@ -176,7 +173,7 @@ class TestEmailService:
result = await service.send_email_verification( result = await service.send_email_verification(
to_email="user@example.com", to_email="user@example.com",
verification_token="verification_token_123", verification_token="verification_token_123",
user_name="Jane" user_name="Jane",
) )
assert result is True assert result is True
@@ -187,8 +184,7 @@ class TestEmailService:
service = EmailService() service = EmailService()
result = await service.send_email_verification( result = await service.send_email_verification(
to_email="user@example.com", to_email="user@example.com", verification_token="verification_token_123"
verification_token="verification_token_123"
) )
assert result is True assert result is True
@@ -202,8 +198,7 @@ class TestEmailService:
token = "test_verification_token_xyz" token = "test_verification_token_xyz"
await service.send_email_verification( await service.send_email_verification(
to_email="user@example.com", to_email="user@example.com", verification_token=token
verification_token=token
) )
# Verify send_email was called # Verify send_email was called
@@ -211,7 +206,7 @@ class TestEmailService:
call_args = backend_mock.send_email.call_args call_args = backend_mock.send_email.call_args
# Check that token is in the HTML content # Check that token is in the HTML content
html_content = call_args.kwargs['html_content'] html_content = call_args.kwargs["html_content"]
assert token in html_content assert token in html_content
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -222,8 +217,7 @@ class TestEmailService:
service = EmailService(backend=backend_mock) service = EmailService(backend=backend_mock)
result = await service.send_email_verification( result = await service.send_email_verification(
to_email="user@example.com", to_email="user@example.com", verification_token="test_token"
verification_token="test_token"
) )
assert result is False assert result is False
@@ -236,14 +230,12 @@ class TestEmailService:
service = EmailService(backend=backend_mock) service = EmailService(backend=backend_mock)
await service.send_password_reset_email( await service.send_password_reset_email(
to_email="user@example.com", to_email="user@example.com", reset_token="token123", user_name="Test User"
reset_token="token123",
user_name="Test User"
) )
call_args = backend_mock.send_email.call_args call_args = backend_mock.send_email.call_args
html_content = call_args.kwargs['html_content'] html_content = call_args.kwargs["html_content"]
text_content = call_args.kwargs['text_content'] text_content = call_args.kwargs["text_content"]
# Check HTML content # Check HTML content
assert "Password Reset" in html_content assert "Password Reset" in html_content
@@ -251,7 +243,9 @@ class TestEmailService:
assert "Test User" in html_content assert "Test User" in html_content
# Check text content # Check text content
assert "Password Reset" in text_content or "password reset" in text_content.lower() assert (
"Password Reset" in text_content or "password reset" in text_content.lower()
)
assert "token123" in text_content assert "token123" in text_content
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -264,12 +258,12 @@ class TestEmailService:
await service.send_email_verification( await service.send_email_verification(
to_email="user@example.com", to_email="user@example.com",
verification_token="verify123", verification_token="verify123",
user_name="Test User" user_name="Test User",
) )
call_args = backend_mock.send_email.call_args call_args = backend_mock.send_email.call_args
html_content = call_args.kwargs['html_content'] html_content = call_args.kwargs["html_content"]
text_content = call_args.kwargs['text_content'] text_content = call_args.kwargs["text_content"]
# Check HTML content # Check HTML content
assert "Verify" in html_content assert "Verify" in html_content

View File

@@ -2,23 +2,27 @@
""" """
Comprehensive tests for session cleanup service. Comprehensive tests for session cleanup service.
""" """
import pytest
import asyncio import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import patch, MagicMock, AsyncMock
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import select
from app.models.user_session import UserSession from app.models.user_session import UserSession
from sqlalchemy import select
class TestCleanupExpiredSessions: class TestCleanupExpiredSessions:
"""Tests for cleanup_expired_sessions function.""" """Tests for cleanup_expired_sessions function."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_expired_sessions_success(self, async_test_db, async_test_user): async def test_cleanup_expired_sessions_success(
self, async_test_db, async_test_user
):
"""Test successful cleanup of expired sessions.""" """Test successful cleanup of expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create mix of sessions # Create mix of sessions
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -30,9 +34,9 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
created_at=datetime.now(timezone.utc) - timedelta(days=1), created_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
# 2. Inactive, expired, old (SHOULD be deleted) # 2. Inactive, expired, old (SHOULD be deleted)
@@ -43,9 +47,9 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.2", ip_address="192.168.1.2",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=10), expires_at=datetime.now(UTC) - timedelta(days=10),
created_at=datetime.now(timezone.utc) - timedelta(days=40), created_at=datetime.now(UTC) - timedelta(days=40),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
# 3. Inactive, expired, recent (should NOT be deleted - within keep_days) # 3. Inactive, expired, recent (should NOT be deleted - within keep_days)
@@ -56,17 +60,23 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.3", ip_address="192.168.1.3",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
created_at=datetime.now(timezone.utc) - timedelta(days=5), created_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add_all([active_session, old_expired_session, recent_expired_session]) session.add_all(
[active_session, old_expired_session, recent_expired_session]
)
await session.commit() await session.commit()
# Mock SessionLocal to return our test session # Mock SessionLocal to return our test session
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30) deleted_count = await cleanup_expired_sessions(keep_days=30)
# Should only delete old_expired_session # Should only delete old_expired_session
@@ -85,7 +95,7 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user): async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user):
"""Test cleanup when no sessions meet deletion criteria.""" """Test cleanup when no sessions meet deletion criteria."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
active = UserSession( active = UserSession(
@@ -95,15 +105,19 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
created_at=datetime.now(timezone.utc), created_at=datetime.now(UTC),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(active) session.add(active)
await session.commit() await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30) deleted_count = await cleanup_expired_sessions(keep_days=30)
assert deleted_count == 0 assert deleted_count == 0
@@ -111,10 +125,14 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_empty_database(self, async_test_db): async def test_cleanup_empty_database(self, async_test_db):
"""Test cleanup with no sessions in database.""" """Test cleanup with no sessions in database."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30) deleted_count = await cleanup_expired_sessions(keep_days=30)
assert deleted_count == 0 assert deleted_count == 0
@@ -122,7 +140,7 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user): async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user):
"""Test cleanup with keep_days=0 deletes all inactive expired sessions.""" """Test cleanup with keep_days=0 deletes all inactive expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
today_expired = UserSession( today_expired = UserSession(
@@ -132,15 +150,19 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1), expires_at=datetime.now(UTC) - timedelta(hours=1),
created_at=datetime.now(timezone.utc) - timedelta(hours=2), created_at=datetime.now(UTC) - timedelta(hours=2),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(today_expired) session.add(today_expired)
await session.commit() await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=0) deleted_count = await cleanup_expired_sessions(keep_days=0)
assert deleted_count == 1 assert deleted_count == 1
@@ -148,7 +170,7 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user): async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user):
"""Test that cleanup uses bulk DELETE for many sessions.""" """Test that cleanup uses bulk DELETE for many sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create 50 expired sessions # Create 50 expired sessions
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -161,16 +183,20 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=10), expires_at=datetime.now(UTC) - timedelta(days=10),
created_at=datetime.now(timezone.utc) - timedelta(days=40), created_at=datetime.now(UTC) - timedelta(days=40),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
sessions_to_add.append(expired) sessions_to_add.append(expired)
session.add_all(sessions_to_add) session.add_all(sessions_to_add)
await session.commit() await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30) deleted_count = await cleanup_expired_sessions(keep_days=30)
assert deleted_count == 50 assert deleted_count == 50
@@ -178,14 +204,20 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_database_error_returns_zero(self, async_test_db): async def test_cleanup_database_error_returns_zero(self, async_test_db):
"""Test cleanup returns 0 on database errors (doesn't crash).""" """Test cleanup returns 0 on database errors (doesn't crash)."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Mock session_crud.cleanup_expired to raise error # Mock session_crud.cleanup_expired to raise error
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): with patch(
with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup: "app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
with patch(
"app.services.session_cleanup.session_crud.cleanup_expired"
) as mock_cleanup:
mock_cleanup.side_effect = Exception("Database connection lost") mock_cleanup.side_effect = Exception("Database connection lost")
from app.services.session_cleanup import cleanup_expired_sessions from app.services.session_cleanup import cleanup_expired_sessions
# Should not crash, should return 0 # Should not crash, should return 0
deleted_count = await cleanup_expired_sessions(keep_days=30) deleted_count = await cleanup_expired_sessions(keep_days=30)
@@ -198,7 +230,7 @@ class TestGetSessionStatistics:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_statistics_with_sessions(self, async_test_db, async_test_user): async def test_get_statistics_with_sessions(self, async_test_db, async_test_user):
"""Test getting session statistics with various session types.""" """Test getting session statistics with various session types."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# 2 active, not expired # 2 active, not expired
@@ -210,9 +242,9 @@ class TestGetSessionStatistics:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
created_at=datetime.now(timezone.utc), created_at=datetime.now(UTC),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(active) session.add(active)
@@ -225,9 +257,9 @@ class TestGetSessionStatistics:
ip_address="192.168.1.2", ip_address="192.168.1.2",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
created_at=datetime.now(timezone.utc) - timedelta(days=2), created_at=datetime.now(UTC) - timedelta(days=2),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(inactive) session.add(inactive)
@@ -239,16 +271,20 @@ class TestGetSessionStatistics:
ip_address="192.168.1.3", ip_address="192.168.1.3",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=True, is_active=True,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1), expires_at=datetime.now(UTC) - timedelta(hours=1),
created_at=datetime.now(timezone.utc) - timedelta(days=1), created_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(expired_active) session.add(expired_active)
await session.commit() await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import get_session_statistics from app.services.session_cleanup import get_session_statistics
stats = await get_session_statistics() stats = await get_session_statistics()
assert stats["total"] == 6 assert stats["total"] == 6
@@ -259,10 +295,14 @@ class TestGetSessionStatistics:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_statistics_empty_database(self, async_test_db): async def test_get_statistics_empty_database(self, async_test_db):
"""Test getting statistics with no sessions.""" """Test getting statistics with no sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()): with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import get_session_statistics from app.services.session_cleanup import get_session_statistics
stats = await get_session_statistics() stats = await get_session_statistics()
assert stats["total"] == 0 assert stats["total"] == 0
@@ -271,9 +311,11 @@ class TestGetSessionStatistics:
assert stats["expired"] == 0 assert stats["expired"] == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_statistics_database_error_returns_empty_dict(self, async_test_db): async def test_get_statistics_database_error_returns_empty_dict(
self, async_test_db
):
"""Test statistics returns empty dict on database errors.""" """Test statistics returns empty dict on database errors."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, _AsyncTestingSessionLocal = async_test_db
# Create a mock that raises on execute # Create a mock that raises on execute
mock_session = AsyncMock() mock_session = AsyncMock()
@@ -283,8 +325,12 @@ class TestGetSessionStatistics:
async def mock_session_local(): async def mock_session_local():
yield mock_session yield mock_session
with patch('app.services.session_cleanup.SessionLocal', return_value=mock_session_local()): with patch(
"app.services.session_cleanup.SessionLocal",
return_value=mock_session_local(),
):
from app.services.session_cleanup import get_session_statistics from app.services.session_cleanup import get_session_statistics
stats = await get_session_statistics() stats = await get_session_statistics()
assert stats == {} assert stats == {}
@@ -294,9 +340,11 @@ class TestConcurrentCleanup:
"""Tests for concurrent cleanup scenarios.""" """Tests for concurrent cleanup scenarios."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_concurrent_cleanup_no_duplicate_deletes(self, async_test_db, async_test_user): async def test_concurrent_cleanup_no_duplicate_deletes(
self, async_test_db, async_test_user
):
"""Test concurrent cleanups don't cause race conditions.""" """Test concurrent cleanups don't cause race conditions."""
test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create 10 expired sessions # Create 10 expired sessions
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -308,20 +356,24 @@ class TestConcurrentCleanup:
ip_address="192.168.1.1", ip_address="192.168.1.1",
user_agent="Mozilla/5.0", user_agent="Mozilla/5.0",
is_active=False, is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=10), expires_at=datetime.now(UTC) - timedelta(days=10),
created_at=datetime.now(timezone.utc) - timedelta(days=40), created_at=datetime.now(UTC) - timedelta(days=40),
last_used_at=datetime.now(timezone.utc) last_used_at=datetime.now(UTC),
) )
session.add(expired) session.add(expired)
await session.commit() await session.commit()
# Run two cleanups concurrently # Run two cleanups concurrently
# Use side_effect to return fresh session instances for each call # Use side_effect to return fresh session instances for each call
with patch('app.services.session_cleanup.SessionLocal', side_effect=lambda: AsyncTestingSessionLocal()): with patch(
"app.services.session_cleanup.SessionLocal",
side_effect=lambda: AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions from app.services.session_cleanup import cleanup_expired_sessions
results = await asyncio.gather( results = await asyncio.gather(
cleanup_expired_sessions(keep_days=30), cleanup_expired_sessions(keep_days=30),
cleanup_expired_sessions(keep_days=30) cleanup_expired_sessions(keep_days=30),
) )
# Both should report deleting sessions (may overlap due to transaction timing) # Both should report deleting sessions (may overlap due to transaction timing)

View File

@@ -2,12 +2,13 @@
""" """
Tests for database initialization script. Tests for database initialization script.
""" """
import pytest
import pytest_asyncio
from unittest.mock import AsyncMock, patch
from app.init_db import init_db from unittest.mock import patch
import pytest
from app.core.config import settings from app.core.config import settings
from app.init_db import init_db
class TestInitDb: class TestInitDb:
@@ -16,69 +17,86 @@ class TestInitDb:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_init_db_creates_superuser_when_not_exists(self, async_test_db): async def test_init_db_creates_superuser_when_not_exists(self, async_test_db):
"""Test that init_db creates a superuser when one doesn't exist.""" """Test that init_db creates a superuser when one doesn't exist."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Mock the SessionLocal to use our test database # Mock the SessionLocal to use our test database
with patch('app.init_db.SessionLocal', SessionLocal): with patch("app.init_db.SessionLocal", SessionLocal):
# Mock settings to provide test credentials # Mock settings to provide test credentials
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test_admin@example.com'): with patch.object(
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestAdmin123!'): settings, "FIRST_SUPERUSER_EMAIL", "test_admin@example.com"
):
with patch.object(
settings, "FIRST_SUPERUSER_PASSWORD", "TestAdmin123!"
):
# Run init_db # Run init_db
user = await init_db() user = await init_db()
# Verify superuser was created # Verify superuser was created
assert user is not None assert user is not None
assert user.email == 'test_admin@example.com' assert user.email == "test_admin@example.com"
assert user.is_superuser is True assert user.is_superuser is True
assert user.first_name == 'Admin' assert user.first_name == "Admin"
assert user.last_name == 'User' assert user.last_name == "User"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_init_db_returns_existing_superuser(self, async_test_db, async_test_user): async def test_init_db_returns_existing_superuser(
self, async_test_db, async_test_user
):
"""Test that init_db returns existing superuser instead of creating duplicate.""" """Test that init_db returns existing superuser instead of creating duplicate."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Mock the SessionLocal to use our test database # Mock the SessionLocal to use our test database
with patch('app.init_db.SessionLocal', SessionLocal): with patch("app.init_db.SessionLocal", SessionLocal):
# Mock settings to match async_test_user's email # Mock settings to match async_test_user's email
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'testuser@example.com'): with patch.object(
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'): settings, "FIRST_SUPERUSER_EMAIL", "testuser@example.com"
):
with patch.object(
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
):
# Run init_db # Run init_db
user = await init_db() user = await init_db()
# Verify it returns the existing user # Verify it returns the existing user
assert user is not None assert user is not None
assert user.id == async_test_user.id assert user.id == async_test_user.id
assert user.email == 'testuser@example.com' assert user.email == "testuser@example.com"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_init_db_uses_default_credentials(self, async_test_db): async def test_init_db_uses_default_credentials(self, async_test_db):
"""Test that init_db uses default credentials when env vars not set.""" """Test that init_db uses default credentials when env vars not set."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Mock the SessionLocal to use our test database # Mock the SessionLocal to use our test database
with patch('app.init_db.SessionLocal', SessionLocal): with patch("app.init_db.SessionLocal", SessionLocal):
# Mock settings to have None values (not configured) # Mock settings to have None values (not configured)
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', None): with patch.object(settings, "FIRST_SUPERUSER_EMAIL", None):
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', None): with patch.object(settings, "FIRST_SUPERUSER_PASSWORD", None):
# Run init_db # Run init_db
user = await init_db() user = await init_db()
# Verify superuser was created with defaults # Verify superuser was created with defaults
assert user is not None assert user is not None
assert user.email == 'admin@example.com' assert user.email == "admin@example.com"
assert user.is_superuser is True assert user.is_superuser is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_init_db_handles_database_errors(self, async_test_db): async def test_init_db_handles_database_errors(self, async_test_db):
"""Test that init_db handles database errors gracefully.""" """Test that init_db handles database errors gracefully."""
test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Mock user_crud.get_by_email to raise an exception # Mock user_crud.get_by_email to raise an exception
with patch('app.init_db.user_crud.get_by_email', side_effect=Exception("Database error")): with patch(
with patch('app.init_db.SessionLocal', SessionLocal): "app.init_db.user_crud.get_by_email",
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test@example.com'): side_effect=Exception("Database error"),
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'): ):
with patch("app.init_db.SessionLocal", SessionLocal):
with patch.object(
settings, "FIRST_SUPERUSER_EMAIL", "test@example.com"
):
with patch.object(
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
):
# Run init_db and expect it to raise # Run init_db and expect it to raise
with pytest.raises(Exception, match="Database error"): with pytest.raises(Exception, match="Database error"):
await init_db() await init_db()

View File

@@ -2,18 +2,18 @@
""" """
Comprehensive tests for device utility functions. Comprehensive tests for device utility functions.
""" """
import pytest
from unittest.mock import Mock from unittest.mock import Mock
from fastapi import Request from fastapi import Request
from app.utils.device import ( from app.utils.device import (
extract_device_info,
parse_device_name,
extract_browser, extract_browser,
extract_device_info,
get_client_ip, get_client_ip,
get_device_type,
is_mobile_device, is_mobile_device,
get_device_type parse_device_name,
) )
@@ -138,7 +138,9 @@ class TestExtractBrowser:
def test_extract_browser_edge_legacy(self): def test_extract_browser_edge_legacy(self):
"""Test extracting legacy Edge browser.""" """Test extracting legacy Edge browser."""
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582" ua = (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
)
result = extract_browser(ua) result = extract_browser(ua)
assert result == "Edge" assert result == "Edge"
@@ -249,7 +251,7 @@ class TestGetClientIp:
request = Mock(spec=Request) request = Mock(spec=Request)
request.headers = { request.headers = {
"x-forwarded-for": "192.168.1.100", "x-forwarded-for": "192.168.1.100",
"x-real-ip": "192.168.1.200" "x-real-ip": "192.168.1.200",
} }
request.client = Mock() request.client = Mock()
request.client.host = "192.168.1.50" request.client.host = "192.168.1.50"
@@ -385,7 +387,7 @@ class TestExtractDeviceInfo:
request.headers = { request.headers = {
"user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)", "user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)",
"x-device-id": "device-123-456", "x-device-id": "device-123-456",
"x-forwarded-for": "192.168.1.100" "x-forwarded-for": "192.168.1.100",
} }
request.client = None request.client = None

View File

@@ -2,19 +2,21 @@
""" """
Tests for security utility functions. Tests for security utility functions.
""" """
import time
import base64 import base64
import json import json
import time
from unittest.mock import MagicMock, patch
import pytest import pytest
from unittest.mock import patch, MagicMock
from app.utils.security import ( from app.utils.security import (
create_upload_token,
verify_upload_token,
create_password_reset_token,
verify_password_reset_token,
create_email_verification_token, create_email_verification_token,
verify_email_verification_token create_password_reset_token,
create_upload_token,
verify_email_verification_token,
verify_password_reset_token,
verify_upload_token,
) )
@@ -31,7 +33,7 @@ class TestCreateUploadToken:
# Token should be base64 encoded # Token should be base64 encoded
try: try:
decoded = base64.urlsafe_b64decode(token.encode('utf-8')) decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded) token_data = json.loads(decoded)
assert "payload" in token_data assert "payload" in token_data
assert "signature" in token_data assert "signature" in token_data
@@ -46,7 +48,7 @@ class TestCreateUploadToken:
token = create_upload_token(file_path, content_type) token = create_upload_token(file_path, content_type)
# Decode and verify payload # Decode and verify payload
decoded = base64.urlsafe_b64decode(token.encode('utf-8')) decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded) token_data = json.loads(decoded)
payload = token_data["payload"] payload = token_data["payload"]
@@ -62,7 +64,7 @@ class TestCreateUploadToken:
after = int(time.time()) after = int(time.time())
# Decode token # Decode token
decoded = base64.urlsafe_b64decode(token.encode('utf-8')) decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded) token_data = json.loads(decoded)
payload = token_data["payload"] payload = token_data["payload"]
@@ -74,11 +76,13 @@ class TestCreateUploadToken:
"""Test token creation with custom expiration time.""" """Test token creation with custom expiration time."""
custom_exp = 600 # 10 minutes custom_exp = 600 # 10 minutes
before = int(time.time()) before = int(time.time())
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp) token = create_upload_token(
"/uploads/test.jpg", "image/jpeg", expires_in=custom_exp
)
after = int(time.time()) after = int(time.time())
# Decode token # Decode token
decoded = base64.urlsafe_b64decode(token.encode('utf-8')) decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded) token_data = json.loads(decoded)
payload = token_data["payload"] payload = token_data["payload"]
@@ -92,11 +96,11 @@ class TestCreateUploadToken:
token2 = create_upload_token("/uploads/test.jpg", "image/jpeg") token2 = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode both tokens # Decode both tokens
decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8')) decoded1 = base64.urlsafe_b64decode(token1.encode("utf-8"))
token_data1 = json.loads(decoded1) token_data1 = json.loads(decoded1)
nonce1 = token_data1["payload"]["nonce"] nonce1 = token_data1["payload"]["nonce"]
decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8')) decoded2 = base64.urlsafe_b64decode(token2.encode("utf-8"))
token_data2 = json.loads(decoded2) token_data2 = json.loads(decoded2)
nonce2 = token_data2["payload"]["nonce"] nonce2 = token_data2["payload"]["nonce"]
@@ -133,7 +137,7 @@ class TestVerifyUploadToken:
current_time = 1000000 current_time = 1000000
mock_time.time = MagicMock(return_value=current_time) mock_time.time = MagicMock(return_value=current_time)
with patch('app.utils.security.time', mock_time): with patch("app.utils.security.time", mock_time):
# Create token that "expires" at current_time + 1 # Create token that "expires" at current_time + 1
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1) token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1)
@@ -149,13 +153,15 @@ class TestVerifyUploadToken:
token = create_upload_token("/uploads/test.jpg", "image/jpeg") token = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode, modify, and re-encode # Decode, modify, and re-encode
decoded = base64.urlsafe_b64decode(token.encode('utf-8')) decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded) token_data = json.loads(decoded)
token_data["signature"] = "invalid_signature" token_data["signature"] = "invalid_signature"
# Re-encode the tampered token # Re-encode the tampered token
tampered_json = json.dumps(token_data) tampered_json = json.dumps(token_data)
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8') tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
"utf-8"
)
payload = verify_upload_token(tampered_token) payload = verify_upload_token(tampered_token)
assert payload is None assert payload is None
@@ -165,13 +171,15 @@ class TestVerifyUploadToken:
token = create_upload_token("/uploads/test.jpg", "image/jpeg") token = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode, modify payload, and re-encode # Decode, modify payload, and re-encode
decoded = base64.urlsafe_b64decode(token.encode('utf-8')) decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded) token_data = json.loads(decoded)
token_data["payload"]["path"] = "/uploads/hacked.exe" token_data["payload"]["path"] = "/uploads/hacked.exe"
# Re-encode the tampered token (signature won't match) # Re-encode the tampered token (signature won't match)
tampered_json = json.dumps(token_data) tampered_json = json.dumps(token_data)
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8') tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
"utf-8"
)
payload = verify_upload_token(tampered_token) payload = verify_upload_token(tampered_token)
assert payload is None assert payload is None
@@ -194,7 +202,9 @@ class TestVerifyUploadToken:
"""Test that tokens with invalid JSON are rejected.""" """Test that tokens with invalid JSON are rejected."""
# Create a base64 string that decodes to invalid JSON # Create a base64 string that decodes to invalid JSON
invalid_json = "not valid json" invalid_json = "not valid json"
invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8') invalid_token = base64.urlsafe_b64encode(invalid_json.encode("utf-8")).decode(
"utf-8"
)
payload = verify_upload_token(invalid_token) payload = verify_upload_token(invalid_token)
assert payload is None assert payload is None
@@ -207,11 +217,13 @@ class TestVerifyUploadToken:
"path": "/uploads/test.jpg" "path": "/uploads/test.jpg"
# Missing content_type, exp, nonce # Missing content_type, exp, nonce
}, },
"signature": "some_signature" "signature": "some_signature",
} }
incomplete_json = json.dumps(incomplete_data) incomplete_json = json.dumps(incomplete_data)
incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8') incomplete_token = base64.urlsafe_b64encode(
incomplete_json.encode("utf-8")
).decode("utf-8")
payload = verify_upload_token(incomplete_token) payload = verify_upload_token(incomplete_token)
assert payload is None assert payload is None
@@ -266,7 +278,7 @@ class TestPasswordResetTokens:
email = "user@example.com" email = "user@example.com"
# Create token that expires in 1 second # Create token that expires in 1 second
with patch('app.utils.security.time') as mock_time: with patch("app.utils.security.time") as mock_time:
mock_time.time = MagicMock(return_value=1000000) mock_time.time = MagicMock(return_value=1000000)
token = create_password_reset_token(email, expires_in=1) token = create_password_reset_token(email, expires_in=1)
@@ -287,12 +299,14 @@ class TestPasswordResetTokens:
token = create_password_reset_token(email) token = create_password_reset_token(email)
# Decode and tamper # Decode and tamper
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded) token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com" token_data["payload"]["email"] = "hacker@example.com"
# Re-encode # Re-encode
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8') tampered = base64.urlsafe_b64encode(
json.dumps(token_data).encode("utf-8")
).decode("utf-8")
verified_email = verify_password_reset_token(tampered) verified_email = verify_password_reset_token(tampered)
assert verified_email is None assert verified_email is None
@@ -312,14 +326,14 @@ class TestPasswordResetTokens:
email = "user@example.com" email = "user@example.com"
custom_exp = 7200 # 2 hours custom_exp = 7200 # 2 hours
with patch('app.utils.security.time') as mock_time: with patch("app.utils.security.time") as mock_time:
current_time = 1000000 current_time = 1000000
mock_time.time = MagicMock(return_value=current_time) mock_time.time = MagicMock(return_value=current_time)
token = create_password_reset_token(email, expires_in=custom_exp) token = create_password_reset_token(email, expires_in=custom_exp)
# Decode to check expiration # Decode to check expiration
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded) token_data = json.loads(decoded)
assert token_data["payload"]["exp"] == current_time + custom_exp assert token_data["payload"]["exp"] == current_time + custom_exp
@@ -350,7 +364,7 @@ class TestEmailVerificationTokens:
"""Test that expired verification tokens are rejected.""" """Test that expired verification tokens are rejected."""
email = "user@example.com" email = "user@example.com"
with patch('app.utils.security.time') as mock_time: with patch("app.utils.security.time") as mock_time:
mock_time.time = MagicMock(return_value=1000000) mock_time.time = MagicMock(return_value=1000000)
token = create_email_verification_token(email, expires_in=1) token = create_email_verification_token(email, expires_in=1)
@@ -371,12 +385,14 @@ class TestEmailVerificationTokens:
token = create_email_verification_token(email) token = create_email_verification_token(email)
# Decode and tamper # Decode and tamper
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded) token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com" token_data["payload"]["email"] = "hacker@example.com"
# Re-encode # Re-encode
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8') tampered = base64.urlsafe_b64encode(
json.dumps(token_data).encode("utf-8")
).decode("utf-8")
verified_email = verify_email_verification_token(tampered) verified_email = verify_email_verification_token(tampered)
assert verified_email is None assert verified_email is None
@@ -395,14 +411,14 @@ class TestEmailVerificationTokens:
"""Test email verification token with default 24-hour expiration.""" """Test email verification token with default 24-hour expiration."""
email = "user@example.com" email = "user@example.com"
with patch('app.utils.security.time') as mock_time: with patch("app.utils.security.time") as mock_time:
current_time = 1000000 current_time = 1000000
mock_time.time = MagicMock(return_value=current_time) mock_time.time = MagicMock(return_value=current_time)
token = create_email_verification_token(email) token = create_email_verification_token(email)
# Decode to check expiration (should be 86400 seconds = 24 hours) # Decode to check expiration (should be 86400 seconds = 24 hours)
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8') decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded) token_data = json.loads(decoded)
assert token_data["payload"]["exp"] == current_time + 86400 assert token_data["payload"]["exp"] == current_time + 86400