Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff

- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest).
- Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting.
- Updated `requirements.txt` to include Ruff and remove replaced tools.
- Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
2025-11-10 11:55:15 +01:00
parent a5c671c133
commit c589b565f0
86 changed files with 4572 additions and 3956 deletions

View File

@@ -2,12 +2,11 @@ import sys
from logging.config import fileConfig
from pathlib import Path
from sqlalchemy import engine_from_config, pool, text, create_engine
from alembic import context
from sqlalchemy import create_engine, engine_from_config, pool, text
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError
from alembic import context
# Get the path to the app directory (parent of 'alembic')
app_dir = Path(__file__).resolve().parent.parent
# Add the app directory to Python path
@@ -66,7 +65,9 @@ def ensure_database_exists(db_url: str) -> None:
admin_url = url.set(database="postgres")
# CREATE DATABASE cannot run inside a transaction
admin_engine = create_engine(str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool)
admin_engine = create_engine(
str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool
)
try:
with admin_engine.connect() as conn:
exists = conn.execute(
@@ -122,9 +123,7 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
@@ -133,4 +132,4 @@ def run_migrations_online() -> None:
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
run_migrations_online()

View File

@@ -5,17 +5,17 @@ Revises: fbf6318a8a36
Create Date: 2025-11-01 04:15:25.367010
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '1174fffbe3e4'
down_revision: Union[str, None] = 'fbf6318a8a36'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "1174fffbe3e4"
down_revision: str | None = "fbf6318a8a36"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
@@ -24,46 +24,46 @@ def upgrade() -> None:
# Index for session cleanup queries
# Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff
op.create_index(
'ix_user_sessions_cleanup',
'user_sessions',
['is_active', 'expires_at', 'created_at'],
"ix_user_sessions_cleanup",
"user_sessions",
["is_active", "expires_at", "created_at"],
unique=False,
postgresql_where=sa.text('is_active = false')
postgresql_where=sa.text("is_active = false"),
)
# Index for user search queries (basic trigram support without pg_trgm extension)
# Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%'
# Note: For better performance, consider enabling pg_trgm extension
op.create_index(
'ix_users_email_lower',
'users',
[sa.text('LOWER(email)')],
"ix_users_email_lower",
"users",
[sa.text("LOWER(email)")],
unique=False,
postgresql_where=sa.text('deleted_at IS NULL')
postgresql_where=sa.text("deleted_at IS NULL"),
)
op.create_index(
'ix_users_first_name_lower',
'users',
[sa.text('LOWER(first_name)')],
"ix_users_first_name_lower",
"users",
[sa.text("LOWER(first_name)")],
unique=False,
postgresql_where=sa.text('deleted_at IS NULL')
postgresql_where=sa.text("deleted_at IS NULL"),
)
op.create_index(
'ix_users_last_name_lower',
'users',
[sa.text('LOWER(last_name)')],
"ix_users_last_name_lower",
"users",
[sa.text("LOWER(last_name)")],
unique=False,
postgresql_where=sa.text('deleted_at IS NULL')
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Index for organization search
op.create_index(
'ix_organizations_name_lower',
'organizations',
[sa.text('LOWER(name)')],
unique=False
"ix_organizations_name_lower",
"organizations",
[sa.text("LOWER(name)")],
unique=False,
)
@@ -71,8 +71,8 @@ def downgrade() -> None:
"""Remove performance indexes."""
# Drop indexes in reverse order
op.drop_index('ix_organizations_name_lower', table_name='organizations')
op.drop_index('ix_users_last_name_lower', table_name='users')
op.drop_index('ix_users_first_name_lower', table_name='users')
op.drop_index('ix_users_email_lower', table_name='users')
op.drop_index('ix_user_sessions_cleanup', table_name='user_sessions')
op.drop_index("ix_organizations_name_lower", table_name="organizations")
op.drop_index("ix_users_last_name_lower", table_name="users")
op.drop_index("ix_users_first_name_lower", table_name="users")
op.drop_index("ix_users_email_lower", table_name="users")
op.drop_index("ix_user_sessions_cleanup", table_name="user_sessions")

View File

@@ -5,30 +5,32 @@ Revises: 9e4f2a1b8c7d
Create Date: 2025-10-30 16:40:21.000021
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '2d0fcec3b06d'
down_revision: Union[str, None] = '9e4f2a1b8c7d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "2d0fcec3b06d"
down_revision: str | None = "9e4f2a1b8c7d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Add deleted_at column for soft deletes
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
op.add_column(
"users", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
)
# Add index on deleted_at for efficient queries
op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])
op.create_index("ix_users_deleted_at", "users", ["deleted_at"])
def downgrade() -> None:
# Remove index
op.drop_index('ix_users_deleted_at', table_name='users')
op.drop_index("ix_users_deleted_at", table_name="users")
# Remove column
op.drop_column('users', 'deleted_at')
op.drop_column("users", "deleted_at")

View File

@@ -5,42 +5,42 @@ Revises: 7396957cbe80
Create Date: 2025-02-28 09:19:33.212278
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '38bf9e7e74b3'
down_revision: Union[str, None] = '7396957cbe80'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "38bf9e7e74b3"
down_revision: str | None = "7396957cbe80"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.create_table('users',
sa.Column('email', sa.String(), nullable=False),
sa.Column('password_hash', sa.String(), nullable=False),
sa.Column('first_name', sa.String(), nullable=False),
sa.Column('last_name', sa.String(), nullable=True),
sa.Column('phone_number', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('is_superuser', sa.Boolean(), nullable=False),
sa.Column('preferences', sa.JSON(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
op.create_table(
"users",
sa.Column("email", sa.String(), nullable=False),
sa.Column("password_hash", sa.String(), nullable=False),
sa.Column("first_name", sa.String(), nullable=False),
sa.Column("last_name", sa.String(), nullable=True),
sa.Column("phone_number", sa.String(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column("preferences", sa.JSON(), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')
op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_table("users")
# ### end Alembic commands ###

View File

@@ -5,98 +5,85 @@ Revises: b76c725fc3cf
Create Date: 2025-10-31 07:41:18.729544
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '549b50ea888d'
down_revision: Union[str, None] = 'b76c725fc3cf'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "549b50ea888d"
down_revision: str | None = "b76c725fc3cf"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Create user_sessions table for per-device session management
op.create_table(
'user_sessions',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
sa.Column('device_name', sa.String(length=255), nullable=True),
sa.Column('device_id', sa.String(length=255), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.String(length=500), nullable=True),
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('location_city', sa.String(length=100), nullable=True),
sa.Column('location_country', sa.String(length=100), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
"user_sessions",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
sa.Column("device_name", sa.String(length=255), nullable=True),
sa.Column("device_id", sa.String(length=255), nullable=True),
sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column("user_agent", sa.String(length=500), nullable=True),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("location_city", sa.String(length=100), nullable=True),
sa.Column("location_country", sa.String(length=100), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create foreign key to users table
op.create_foreign_key(
'fk_user_sessions_user_id',
'user_sessions',
'users',
['user_id'],
['id'],
ondelete='CASCADE'
"fk_user_sessions_user_id",
"user_sessions",
"users",
["user_id"],
["id"],
ondelete="CASCADE",
)
# Create indexes for performance
# 1. Lookup session by refresh token JTI (most common query)
op.create_index(
'ix_user_sessions_jti',
'user_sessions',
['refresh_token_jti'],
unique=True
"ix_user_sessions_jti", "user_sessions", ["refresh_token_jti"], unique=True
)
# 2. Lookup sessions by user ID
op.create_index(
'ix_user_sessions_user_id',
'user_sessions',
['user_id']
)
op.create_index("ix_user_sessions_user_id", "user_sessions", ["user_id"])
# 3. Composite index for active sessions by user
op.create_index(
'ix_user_sessions_user_active',
'user_sessions',
['user_id', 'is_active']
"ix_user_sessions_user_active", "user_sessions", ["user_id", "is_active"]
)
# 4. Index on expires_at for cleanup job
op.create_index(
'ix_user_sessions_expires_at',
'user_sessions',
['expires_at']
)
op.create_index("ix_user_sessions_expires_at", "user_sessions", ["expires_at"])
# 5. Composite index for active session lookup by JTI
op.create_index(
'ix_user_sessions_jti_active',
'user_sessions',
['refresh_token_jti', 'is_active']
"ix_user_sessions_jti_active",
"user_sessions",
["refresh_token_jti", "is_active"],
)
def downgrade() -> None:
# Drop indexes first
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
op.drop_index('ix_user_sessions_expires_at', table_name='user_sessions')
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
op.drop_index('ix_user_sessions_user_id', table_name='user_sessions')
op.drop_index('ix_user_sessions_jti', table_name='user_sessions')
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
op.drop_index("ix_user_sessions_expires_at", table_name="user_sessions")
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
op.drop_index("ix_user_sessions_user_id", table_name="user_sessions")
op.drop_index("ix_user_sessions_jti", table_name="user_sessions")
# Drop foreign key
op.drop_constraint('fk_user_sessions_user_id', 'user_sessions', type_='foreignkey')
op.drop_constraint("fk_user_sessions_user_id", "user_sessions", type_="foreignkey")
# Drop table
op.drop_table('user_sessions')
op.drop_table("user_sessions")

View File

@@ -1,19 +1,18 @@
"""Initial empty migration
Revision ID: 7396957cbe80
Revises:
Revises:
Create Date: 2025-02-27 12:47:46.445313
"""
from typing import Sequence, Union
from alembic import op
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = '7396957cbe80'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "7396957cbe80"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:

View File

@@ -5,80 +5,112 @@ Revises: 38bf9e7e74b3
Create Date: 2025-10-30 10:00:00.000000
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '9e4f2a1b8c7d'
down_revision: Union[str, None] = '38bf9e7e74b3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "9e4f2a1b8c7d"
down_revision: str | None = "38bf9e7e74b3"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Add missing indexes for is_active and is_superuser
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
op.create_index(
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
)
# Fix column types to match model definitions with explicit lengths
op.alter_column('users', 'email',
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False)
op.alter_column(
"users",
"email",
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False,
)
op.alter_column('users', 'password_hash',
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False)
op.alter_column(
"users",
"password_hash",
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False,
)
op.alter_column('users', 'first_name',
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=False,
server_default='user') # Add server default
op.alter_column(
"users",
"first_name",
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=False,
server_default="user",
) # Add server default
op.alter_column('users', 'last_name',
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=True)
op.alter_column(
"users",
"last_name",
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=True,
)
op.alter_column('users', 'phone_number',
existing_type=sa.String(),
type_=sa.String(length=20),
nullable=True)
op.alter_column(
"users",
"phone_number",
existing_type=sa.String(),
type_=sa.String(length=20),
nullable=True,
)
def downgrade() -> None:
# Revert column types
op.alter_column('users', 'phone_number',
existing_type=sa.String(length=20),
type_=sa.String(),
nullable=True)
op.alter_column(
"users",
"phone_number",
existing_type=sa.String(length=20),
type_=sa.String(),
nullable=True,
)
op.alter_column('users', 'last_name',
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=True)
op.alter_column(
"users",
"last_name",
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=True,
)
op.alter_column('users', 'first_name',
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=False,
server_default=None) # Remove server default
op.alter_column(
"users",
"first_name",
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=False,
server_default=None,
) # Remove server default
op.alter_column('users', 'password_hash',
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False)
op.alter_column(
"users",
"password_hash",
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False,
)
op.alter_column('users', 'email',
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False)
op.alter_column(
"users",
"email",
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False,
)
# Drop indexes
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
op.drop_index(op.f('ix_users_is_active'), table_name='users')
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
op.drop_index(op.f("ix_users_is_active"), table_name="users")

View File

@@ -5,17 +5,17 @@ Revises: 2d0fcec3b06d
Create Date: 2025-10-30 16:41:33.273135
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'b76c725fc3cf'
down_revision: Union[str, None] = '2d0fcec3b06d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "b76c725fc3cf"
down_revision: str | None = "2d0fcec3b06d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
@@ -23,30 +23,26 @@ def upgrade() -> None:
# Composite index for filtering active users by role
op.create_index(
'ix_users_active_superuser',
'users',
['is_active', 'is_superuser'],
postgresql_where=sa.text('deleted_at IS NULL')
"ix_users_active_superuser",
"users",
["is_active", "is_superuser"],
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Composite index for sorting active users by creation date
op.create_index(
'ix_users_active_created',
'users',
['is_active', 'created_at'],
postgresql_where=sa.text('deleted_at IS NULL')
"ix_users_active_created",
"users",
["is_active", "created_at"],
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Composite index for email lookup of non-deleted users
op.create_index(
'ix_users_email_not_deleted',
'users',
['email', 'deleted_at']
)
op.create_index("ix_users_email_not_deleted", "users", ["email", "deleted_at"])
def downgrade() -> None:
# Remove composite indexes
op.drop_index('ix_users_email_not_deleted', table_name='users')
op.drop_index('ix_users_active_created', table_name='users')
op.drop_index('ix_users_active_superuser', table_name='users')
op.drop_index("ix_users_email_not_deleted", table_name="users")
op.drop_index("ix_users_active_created", table_name="users")
op.drop_index("ix_users_active_superuser", table_name="users")

View File

@@ -5,102 +5,123 @@ Revises: 549b50ea888d
Create Date: 2025-10-31 12:08:05.141353
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'fbf6318a8a36'
down_revision: Union[str, None] = '549b50ea888d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "fbf6318a8a36"
down_revision: str | None = "549b50ea888d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Create organizations table
op.create_table(
'organizations',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('slug', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('settings', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
"organizations",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("slug", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("settings", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create indexes for organizations
op.create_index('ix_organizations_name', 'organizations', ['name'])
op.create_index('ix_organizations_slug', 'organizations', ['slug'], unique=True)
op.create_index('ix_organizations_is_active', 'organizations', ['is_active'])
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'])
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'])
op.create_index("ix_organizations_name", "organizations", ["name"])
op.create_index("ix_organizations_slug", "organizations", ["slug"], unique=True)
op.create_index("ix_organizations_is_active", "organizations", ["is_active"])
op.create_index(
"ix_organizations_name_active", "organizations", ["name", "is_active"]
)
op.create_index(
"ix_organizations_slug_active", "organizations", ["slug", "is_active"]
)
# Create user_organizations junction table
op.create_table(
'user_organizations',
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('organization_id', sa.UUID(), nullable=False),
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False, server_default='MEMBER'),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('user_id', 'organization_id')
"user_organizations",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("organization_id", sa.UUID(), nullable=False),
sa.Column(
"role",
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
nullable=False,
server_default="MEMBER",
),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("user_id", "organization_id"),
)
# Create foreign keys
op.create_foreign_key(
'fk_user_organizations_user_id',
'user_organizations',
'users',
['user_id'],
['id'],
ondelete='CASCADE'
"fk_user_organizations_user_id",
"user_organizations",
"users",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
'fk_user_organizations_organization_id',
'user_organizations',
'organizations',
['organization_id'],
['id'],
ondelete='CASCADE'
"fk_user_organizations_organization_id",
"user_organizations",
"organizations",
["organization_id"],
["id"],
ondelete="CASCADE",
)
# Create indexes for user_organizations
op.create_index('ix_user_organizations_role', 'user_organizations', ['role'])
op.create_index('ix_user_organizations_is_active', 'user_organizations', ['is_active'])
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'])
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'])
op.create_index("ix_user_organizations_role", "user_organizations", ["role"])
op.create_index(
"ix_user_organizations_is_active", "user_organizations", ["is_active"]
)
op.create_index(
"ix_user_org_user_active", "user_organizations", ["user_id", "is_active"]
)
op.create_index(
"ix_user_org_org_active", "user_organizations", ["organization_id", "is_active"]
)
def downgrade() -> None:
# Drop indexes for user_organizations
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
op.drop_index('ix_user_organizations_is_active', table_name='user_organizations')
op.drop_index('ix_user_organizations_role', table_name='user_organizations')
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
op.drop_index("ix_user_organizations_is_active", table_name="user_organizations")
op.drop_index("ix_user_organizations_role", table_name="user_organizations")
# Drop foreign keys
op.drop_constraint('fk_user_organizations_organization_id', 'user_organizations', type_='foreignkey')
op.drop_constraint('fk_user_organizations_user_id', 'user_organizations', type_='foreignkey')
op.drop_constraint(
"fk_user_organizations_organization_id",
"user_organizations",
type_="foreignkey",
)
op.drop_constraint(
"fk_user_organizations_user_id", "user_organizations", type_="foreignkey"
)
# Drop user_organizations table
op.drop_table('user_organizations')
op.drop_table("user_organizations")
# Drop indexes for organizations
op.drop_index('ix_organizations_slug_active', table_name='organizations')
op.drop_index('ix_organizations_name_active', table_name='organizations')
op.drop_index('ix_organizations_is_active', table_name='organizations')
op.drop_index('ix_organizations_slug', table_name='organizations')
op.drop_index('ix_organizations_name', table_name='organizations')
op.drop_index("ix_organizations_slug_active", table_name="organizations")
op.drop_index("ix_organizations_name_active", table_name="organizations")
op.drop_index("ix_organizations_is_active", table_name="organizations")
op.drop_index("ix_organizations_slug", table_name="organizations")
op.drop_index("ix_organizations_name", table_name="organizations")
# Drop organizations table
op.drop_table('organizations')
op.drop_table("organizations")
# Drop enum type
op.execute('DROP TYPE IF EXISTS organizationrole')
op.execute("DROP TYPE IF EXISTS organizationrole")

View File

@@ -1,12 +1,10 @@
from typing import Optional
from fastapi import Depends, HTTPException, status, Header
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
from app.core.database import get_db
from app.models.user import User
@@ -15,8 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
db: AsyncSession = Depends(get_db),
token: str = Depends(oauth2_scheme)
db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
) -> User:
"""
Get the current authenticated user.
@@ -36,21 +33,17 @@ async def get_current_user(
token_data = get_token_data(token)
# Get user from database
result = await db.execute(
select(User).where(User.id == token_data.user_id)
)
result = await db.execute(select(User).where(User.id == token_data.user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
@@ -59,19 +52,17 @@ async def get_current_user(
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token expired",
headers={"WWW-Authenticate": "Bearer"}
headers={"WWW-Authenticate": "Bearer"},
)
except TokenInvalidError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}
headers={"WWW-Authenticate": "Bearer"},
)
def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
"""
Check if the current user is active.
@@ -86,15 +77,12 @@ def get_current_active_user(
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return current_user
def get_current_superuser(
current_user: User = Depends(get_current_user)
) -> User:
def get_current_superuser(current_user: User = Depends(get_current_user)) -> User:
"""
Check if the current user is a superuser.
@@ -109,13 +97,12 @@ def get_current_superuser(
"""
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions"
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
)
return current_user
async def get_optional_token(authorization: str = Header(None)) -> Optional[str]:
async def get_optional_token(authorization: str = Header(None)) -> str | None:
"""
Get the token from the Authorization header without requiring it.
@@ -139,9 +126,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
async def get_optional_current_user(
db: AsyncSession = Depends(get_db),
token: Optional[str] = Depends(get_optional_token)
) -> Optional[User]:
db: AsyncSession = Depends(get_db), token: str | None = Depends(get_optional_token)
) -> User | None:
"""
Get the current user if authenticated, otherwise return None.
Useful for endpoints that work with both authenticated and unauthenticated users.
@@ -158,12 +144,10 @@ async def get_optional_current_user(
try:
token_data = get_token_data(token)
result = await db.execute(
select(User).where(User.id == token_data.user_id)
)
result = await db.execute(select(User).where(User.id == token_data.user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active:
return None
return user
except (TokenExpiredError, TokenInvalidError):
return None
return None

View File

@@ -7,7 +7,7 @@ These dependencies are optional and flexible:
- Use require_org_role for organization-specific access control
- Projects can choose to use these or implement their own permission system
"""
from typing import Optional
from uuid import UUID
from fastapi import Depends, HTTPException, status
@@ -20,9 +20,7 @@ from app.models.user import User
from app.models.user_organization import OrganizationRole
def require_superuser(
current_user: User = Depends(get_current_user)
) -> User:
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
"""
Dependency to ensure the current user is a superuser.
@@ -36,7 +34,7 @@ def require_superuser(
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Superuser privileges required"
detail="Superuser privileges required",
)
return current_user
@@ -62,7 +60,7 @@ class OrganizationPermission:
self,
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> User:
"""
Check if user has required role in the organization.
@@ -84,21 +82,19 @@ class OrganizationPermission:
# Get user's role in organization
user_role = await organization_crud.get_user_role_in_org(
db,
user_id=current_user.id,
organization_id=organization_id
db, user_id=current_user.id, organization_id=organization_id
)
if not user_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not a member of this organization"
detail="Not a member of this organization",
)
if user_role not in self.allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}"
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}",
)
return current_user
@@ -106,18 +102,18 @@ class OrganizationPermission:
# Common permission presets for convenience
require_org_owner = OrganizationPermission([OrganizationRole.OWNER])
require_org_admin = OrganizationPermission([OrganizationRole.OWNER, OrganizationRole.ADMIN])
require_org_member = OrganizationPermission([
OrganizationRole.OWNER,
OrganizationRole.ADMIN,
OrganizationRole.MEMBER
])
require_org_admin = OrganizationPermission(
[OrganizationRole.OWNER, OrganizationRole.ADMIN]
)
require_org_member = OrganizationPermission(
[OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MEMBER]
)
async def require_org_membership(
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> User:
"""
Ensure user is a member of the organization (any role).
@@ -128,15 +124,13 @@ async def require_org_membership(
return current_user
user_role = await organization_crud.get_user_role_in_org(
db,
user_id=current_user.id,
organization_id=organization_id
db, user_id=current_user.id, organization_id=organization_id
)
if not user_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not a member of this organization"
detail="Not a member of this organization",
)
return current_user

View File

@@ -1,10 +1,12 @@
from fastapi import APIRouter
from app.api.routes import auth, users, sessions, admin, organizations
from app.api.routes import admin, auth, organizations, sessions, users
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
api_router.include_router(users.router, prefix="/users", tags=["Users"])
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
api_router.include_router(organizations.router, prefix="/organizations", tags=["Organizations"])
api_router.include_router(
organizations.router, prefix="/organizations", tags=["Organizations"]
)

View File

@@ -5,9 +5,10 @@ Admin-specific endpoints for managing users and organizations.
These endpoints require superuser privileges and provide CMS-like functionality
for managing the application.
"""
import logging
from enum import Enum
from typing import Any, List, Optional
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query, status
@@ -16,27 +17,32 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.permissions import require_superuser
from app.core.database import get_db
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode
from app.core.exceptions import (
AuthorizationError,
DuplicateError,
ErrorCode,
NotFoundError,
)
from app.crud.organization import organization as organization_crud
from app.crud.user import user as user_crud
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.user import User
from app.models.user_organization import OrganizationRole
from app.schemas.common import (
PaginationParams,
PaginatedResponse,
MessageResponse,
PaginatedResponse,
PaginationParams,
SortParams,
create_pagination_meta
create_pagination_meta,
)
from app.schemas.organizations import (
OrganizationResponse,
OrganizationCreate,
OrganizationMemberResponse,
OrganizationResponse,
OrganizationUpdate,
OrganizationMemberResponse
)
from app.schemas.users import UserResponse, UserCreate, UserUpdate
from app.schemas.sessions import AdminSessionResponse
from app.schemas.users import UserCreate, UserResponse, UserUpdate
logger = logging.getLogger(__name__)
@@ -46,6 +52,7 @@ router = APIRouter()
# Schemas for bulk operations
class BulkAction(str, Enum):
"""Supported bulk actions."""
ACTIVATE = "activate"
DEACTIVATE = "deactivate"
DELETE = "delete"
@@ -53,36 +60,41 @@ class BulkAction(str, Enum):
class BulkUserAction(BaseModel):
"""Schema for bulk user actions."""
action: BulkAction = Field(..., description="Action to perform on selected users")
user_ids: List[UUID] = Field(..., min_items=1, max_items=100, description="List of user IDs (max 100)")
user_ids: list[UUID] = Field(
..., min_items=1, max_items=100, description="List of user IDs (max 100)"
)
class BulkActionResult(BaseModel):
"""Result of a bulk action."""
success: bool
affected_count: int
failed_count: int
message: str
failed_ids: Optional[List[UUID]] = []
failed_ids: list[UUID] | None = []
# ===== User Management Endpoints =====
@router.get(
"/users",
response_model=PaginatedResponse[UserResponse],
summary="Admin: List All Users",
description="Get paginated list of all users with filtering and search (admin only)",
operation_id="admin_list_users"
operation_id="admin_list_users",
)
async def admin_list_users(
pagination: PaginationParams = Depends(),
sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
search: Optional[str] = Query(None, description="Search by email, name"),
is_active: bool | None = Query(None, description="Filter by active status"),
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
search: str | None = Query(None, description="Search by email, name"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List all users with comprehensive filtering and search.
@@ -105,20 +117,20 @@ async def admin_list_users(
sort_by=sort.sort_by or "created_at",
sort_order=sort.sort_order.value if sort.sort_order else "desc",
filters=filters if filters else None,
search=search
search=search,
)
pagination_meta = create_pagination_meta(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(users)
items_count=len(users),
)
return PaginatedResponse(data=users, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing users (admin): {str(e)}", exc_info=True)
logger.error(f"Error listing users (admin): {e!s}", exc_info=True)
raise
@@ -128,12 +140,12 @@ async def admin_list_users(
status_code=status.HTTP_201_CREATED,
summary="Admin: Create User",
description="Create a new user (admin only)",
operation_id="admin_create_user"
operation_id="admin_create_user",
)
async def admin_create_user(
user_in: UserCreate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Create a new user with admin privileges.
@@ -145,13 +157,10 @@ async def admin_create_user(
logger.info(f"Admin {admin.email} created user {user.email}")
return user
except ValueError as e:
logger.warning(f"Failed to create user: {str(e)}")
raise NotFoundError(
message=str(e),
error_code=ErrorCode.USER_ALREADY_EXISTS
)
logger.warning(f"Failed to create user: {e!s}")
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
except Exception as e:
logger.error(f"Error creating user (admin): {str(e)}", exc_info=True)
logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
raise
@@ -160,19 +169,18 @@ async def admin_create_user(
response_model=UserResponse,
summary="Admin: Get User Details",
description="Get detailed user information (admin only)",
operation_id="admin_get_user"
operation_id="admin_get_user",
)
async def admin_get_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Get detailed information about a specific user."""
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
return user
@@ -182,21 +190,20 @@ async def admin_get_user(
response_model=UserResponse,
summary="Admin: Update User",
description="Update user information (admin only)",
operation_id="admin_update_user"
operation_id="admin_update_user",
)
async def admin_update_user(
user_id: UUID,
user_in: UserUpdate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Update user information with admin privileges."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
@@ -206,7 +213,7 @@ async def admin_update_user(
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error updating user (admin): {str(e)}", exc_info=True)
logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
raise
@@ -215,20 +222,19 @@ async def admin_update_user(
response_model=MessageResponse,
summary="Admin: Delete User",
description="Soft delete a user (admin only)",
operation_id="admin_delete_user"
operation_id="admin_delete_user",
)
async def admin_delete_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Soft delete a user (sets deleted_at timestamp)."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
# Prevent deleting yourself
@@ -236,21 +242,20 @@ async def admin_delete_user(
# Use AuthorizationError for permission/operation restrictions
raise AuthorizationError(
message="Cannot delete your own account",
error_code=ErrorCode.OPERATION_FORBIDDEN
error_code=ErrorCode.OPERATION_FORBIDDEN,
)
await user_crud.soft_delete(db, id=user_id)
logger.info(f"Admin {admin.email} deleted user {user.email}")
return MessageResponse(
success=True,
message=f"User {user.email} has been deleted"
success=True, message=f"User {user.email} has been deleted"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error deleting user (admin): {str(e)}", exc_info=True)
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
raise
@@ -259,34 +264,32 @@ async def admin_delete_user(
response_model=MessageResponse,
summary="Admin: Activate User",
description="Activate a user account (admin only)",
operation_id="admin_activate_user"
operation_id="admin_activate_user",
)
async def admin_activate_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Activate a user account."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
logger.info(f"Admin {admin.email} activated user {user.email}")
return MessageResponse(
success=True,
message=f"User {user.email} has been activated"
success=True, message=f"User {user.email} has been activated"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error activating user (admin): {str(e)}", exc_info=True)
logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
raise
@@ -295,20 +298,19 @@ async def admin_activate_user(
response_model=MessageResponse,
summary="Admin: Deactivate User",
description="Deactivate a user account (admin only)",
operation_id="admin_deactivate_user"
operation_id="admin_deactivate_user",
)
async def admin_deactivate_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Deactivate a user account."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
# Prevent deactivating yourself
@@ -316,21 +318,20 @@ async def admin_deactivate_user(
# Use AuthorizationError for permission/operation restrictions
raise AuthorizationError(
message="Cannot deactivate your own account",
error_code=ErrorCode.OPERATION_FORBIDDEN
error_code=ErrorCode.OPERATION_FORBIDDEN,
)
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
logger.info(f"Admin {admin.email} deactivated user {user.email}")
return MessageResponse(
success=True,
message=f"User {user.email} has been deactivated"
success=True, message=f"User {user.email} has been deactivated"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error deactivating user (admin): {str(e)}", exc_info=True)
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
raise
@@ -339,12 +340,12 @@ async def admin_deactivate_user(
response_model=BulkActionResult,
summary="Admin: Bulk User Action",
description="Perform bulk actions on multiple users (admin only)",
operation_id="admin_bulk_user_action"
operation_id="admin_bulk_user_action",
)
async def admin_bulk_user_action(
bulk_action: BulkUserAction,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Perform bulk actions on multiple users using optimized bulk operations.
@@ -356,22 +357,16 @@ async def admin_bulk_user_action(
# Use efficient bulk operations instead of loop
if bulk_action.action == BulkAction.ACTIVATE:
affected_count = await user_crud.bulk_update_status(
db,
user_ids=bulk_action.user_ids,
is_active=True
db, user_ids=bulk_action.user_ids, is_active=True
)
elif bulk_action.action == BulkAction.DEACTIVATE:
affected_count = await user_crud.bulk_update_status(
db,
user_ids=bulk_action.user_ids,
is_active=False
db, user_ids=bulk_action.user_ids, is_active=False
)
elif bulk_action.action == BulkAction.DELETE:
# bulk_soft_delete automatically excludes the admin user
affected_count = await user_crud.bulk_soft_delete(
db,
user_ids=bulk_action.user_ids,
exclude_user_id=admin.id
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
)
else:
raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
@@ -390,29 +385,30 @@ async def admin_bulk_user_action(
affected_count=affected_count,
failed_count=failed_count,
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped",
failed_ids=None # Bulk operations don't track individual failures
failed_ids=None, # Bulk operations don't track individual failures
)
except Exception as e:
logger.error(f"Error in bulk user action: {str(e)}", exc_info=True)
logger.error(f"Error in bulk user action: {e!s}", exc_info=True)
raise
# ===== Organization Management Endpoints =====
@router.get(
"/organizations",
response_model=PaginatedResponse[OrganizationResponse],
summary="Admin: List Organizations",
description="Get paginated list of all organizations (admin only)",
operation_id="admin_list_organizations"
operation_id="admin_list_organizations",
)
async def admin_list_organizations(
pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
search: Optional[str] = Query(None, description="Search by name, slug, description"),
is_active: bool | None = Query(None, description="Filter by active status"),
search: str | None = Query(None, description="Search by name, slug, description"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""List all organizations with filtering and search."""
try:
@@ -422,14 +418,14 @@ async def admin_list_organizations(
skip=pagination.offset,
limit=pagination.limit,
is_active=is_active,
search=search
search=search,
)
# Build response objects from optimized query results
orgs_with_count = []
for item in orgs_with_data:
org = item['organization']
member_count = item['member_count']
org = item["organization"]
member_count = item["member_count"]
org_dict = {
"id": org.id,
@@ -440,7 +436,7 @@ async def admin_list_organizations(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": member_count
"member_count": member_count,
}
orgs_with_count.append(OrganizationResponse(**org_dict))
@@ -448,13 +444,13 @@ async def admin_list_organizations(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(orgs_with_count)
items_count=len(orgs_with_count),
)
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing organizations (admin): {str(e)}", exc_info=True)
logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True)
raise
@@ -464,12 +460,12 @@ async def admin_list_organizations(
status_code=status.HTTP_201_CREATED,
summary="Admin: Create Organization",
description="Create a new organization (admin only)",
operation_id="admin_create_organization"
operation_id="admin_create_organization",
)
async def admin_create_organization(
org_in: OrganizationCreate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Create a new organization."""
try:
@@ -486,18 +482,15 @@ async def admin_create_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": 0
"member_count": 0,
}
return OrganizationResponse(**org_dict)
except ValueError as e:
logger.warning(f"Failed to create organization: {str(e)}")
raise NotFoundError(
message=str(e),
error_code=ErrorCode.ALREADY_EXISTS
)
logger.warning(f"Failed to create organization: {e!s}")
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
except Exception as e:
logger.error(f"Error creating organization (admin): {str(e)}", exc_info=True)
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
raise
@@ -506,19 +499,18 @@ async def admin_create_organization(
response_model=OrganizationResponse,
summary="Admin: Get Organization Details",
description="Get detailed organization information (admin only)",
operation_id="admin_get_organization"
operation_id="admin_get_organization",
)
async def admin_get_organization(
org_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Get detailed information about a specific organization."""
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
)
org_dict = {
@@ -530,7 +522,9 @@ async def admin_get_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
"member_count": await organization_crud.get_member_count(
db, organization_id=org.id
),
}
return OrganizationResponse(**org_dict)
@@ -540,13 +534,13 @@ async def admin_get_organization(
response_model=OrganizationResponse,
summary="Admin: Update Organization",
description="Update organization information (admin only)",
operation_id="admin_update_organization"
operation_id="admin_update_organization",
)
async def admin_update_organization(
org_id: UUID,
org_in: OrganizationUpdate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Update organization information."""
try:
@@ -554,7 +548,7 @@ async def admin_update_organization(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
@@ -569,14 +563,16 @@ async def admin_update_organization(
"settings": updated_org.settings,
"created_at": updated_org.created_at,
"updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
"member_count": await organization_crud.get_member_count(
db, organization_id=updated_org.id
),
}
return OrganizationResponse(**org_dict)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error updating organization (admin): {str(e)}", exc_info=True)
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
raise
@@ -585,12 +581,12 @@ async def admin_update_organization(
response_model=MessageResponse,
summary="Admin: Delete Organization",
description="Delete an organization (admin only)",
operation_id="admin_delete_organization"
operation_id="admin_delete_organization",
)
async def admin_delete_organization(
org_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Delete an organization and all its relationships."""
try:
@@ -598,21 +594,20 @@ async def admin_delete_organization(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
await organization_crud.remove(db, id=org_id)
logger.info(f"Admin {admin.email} deleted organization {org.name}")
return MessageResponse(
success=True,
message=f"Organization {org.name} has been deleted"
success=True, message=f"Organization {org.name} has been deleted"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error deleting organization (admin): {str(e)}", exc_info=True)
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
raise
@@ -621,14 +616,14 @@ async def admin_delete_organization(
response_model=PaginatedResponse[OrganizationMemberResponse],
summary="Admin: List Organization Members",
description="Get all members of an organization (admin only)",
operation_id="admin_list_organization_members"
operation_id="admin_list_organization_members",
)
async def admin_list_organization_members(
org_id: UUID,
pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(True, description="Filter by active status"),
is_active: bool | None = Query(True, description="Filter by active status"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""List all members of an organization."""
try:
@@ -636,7 +631,7 @@ async def admin_list_organization_members(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
members, total = await organization_crud.get_organization_members(
@@ -644,7 +639,7 @@ async def admin_list_organization_members(
organization_id=org_id,
skip=pagination.offset,
limit=pagination.limit,
is_active=is_active
is_active=is_active,
)
# Convert to response models
@@ -654,7 +649,7 @@ async def admin_list_organization_members(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(member_responses)
items_count=len(member_responses),
)
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
@@ -662,14 +657,19 @@ async def admin_list_organization_members(
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error listing organization members (admin): {str(e)}", exc_info=True)
logger.error(
f"Error listing organization members (admin): {e!s}", exc_info=True
)
raise
class AddMemberRequest(BaseModel):
"""Request to add a member to an organization."""
user_id: UUID = Field(..., description="User ID to add")
role: OrganizationRole = Field(OrganizationRole.MEMBER, description="Role in organization")
role: OrganizationRole = Field(
OrganizationRole.MEMBER, description="Role in organization"
)
@router.post(
@@ -677,13 +677,13 @@ class AddMemberRequest(BaseModel):
response_model=MessageResponse,
summary="Admin: Add Member to Organization",
description="Add a user to an organization (admin only)",
operation_id="admin_add_organization_member"
operation_id="admin_add_organization_member",
)
async def admin_add_organization_member(
org_id: UUID,
request: AddMemberRequest,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Add a user to an organization."""
try:
@@ -691,21 +691,18 @@ async def admin_add_organization_member(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
user = await user_crud.get(db, id=request.user_id)
if not user:
raise NotFoundError(
message=f"User {request.user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND,
)
await organization_crud.add_user(
db,
organization_id=org_id,
user_id=request.user_id,
role=request.role
db, organization_id=org_id, user_id=request.user_id, role=request.role
)
logger.info(
@@ -714,22 +711,21 @@ async def admin_add_organization_member(
)
return MessageResponse(
success=True,
message=f"User {user.email} added to organization {org.name}"
success=True, message=f"User {user.email} added to organization {org.name}"
)
except ValueError as e:
logger.warning(f"Failed to add user to organization: {str(e)}")
logger.warning(f"Failed to add user to organization: {e!s}")
# Use DuplicateError for "already exists" scenarios
raise DuplicateError(
message=str(e),
error_code=ErrorCode.USER_ALREADY_EXISTS,
field="user_id"
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error adding member to organization (admin): {str(e)}", exc_info=True)
logger.error(
f"Error adding member to organization (admin): {e!s}", exc_info=True
)
raise
@@ -738,13 +734,13 @@ async def admin_add_organization_member(
response_model=MessageResponse,
summary="Admin: Remove Member from Organization",
description="Remove a user from an organization (admin only)",
operation_id="admin_remove_organization_member"
operation_id="admin_remove_organization_member",
)
async def admin_remove_organization_member(
org_id: UUID,
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Remove a user from an organization."""
try:
@@ -752,39 +748,40 @@ async def admin_remove_organization_member(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
success = await organization_crud.remove_user(
db,
organization_id=org_id,
user_id=user_id
db, organization_id=org_id, user_id=user_id
)
if not success:
raise NotFoundError(
message="User is not a member of this organization",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
logger.info(f"Admin {admin.email} removed user {user.email} from organization {org.name}")
logger.info(
f"Admin {admin.email} removed user {user.email} from organization {org.name}"
)
return MessageResponse(
success=True,
message=f"User {user.email} removed from organization {org.name}"
message=f"User {user.email} removed from organization {org.name}",
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error removing member from organization (admin): {str(e)}", exc_info=True)
logger.error(
f"Error removing member from organization (admin): {e!s}", exc_info=True
)
raise
@@ -792,6 +789,7 @@ async def admin_remove_organization_member(
# Session Management Endpoints
# ============================================================================
@router.get(
"/sessions",
response_model=PaginatedResponse[AdminSessionResponse],
@@ -802,13 +800,13 @@ async def admin_remove_organization_member(
Returns paginated list of sessions with user information.
Useful for admin dashboard statistics and session monitoring.
""",
operation_id="admin_list_sessions"
operation_id="admin_list_sessions",
)
async def admin_list_sessions(
pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_active: bool | None = Query(None, description="Filter by active status"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""List all sessions across all users with filtering and pagination."""
try:
@@ -818,7 +816,7 @@ async def admin_list_sessions(
skip=pagination.offset,
limit=pagination.limit,
active_only=is_active if is_active is not None else True,
with_user=True
with_user=True,
)
# Build response objects with user information
@@ -847,21 +845,23 @@ async def admin_list_sessions(
last_used_at=session.last_used_at,
created_at=session.created_at,
expires_at=session.expires_at,
is_active=session.is_active
is_active=session.is_active,
)
session_responses.append(session_response)
logger.info(f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})")
logger.info(
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})"
)
pagination_meta = create_pagination_meta(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(session_responses)
items_count=len(session_responses),
)
return PaginatedResponse(data=session_responses, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing sessions (admin): {str(e)}", exc_info=True)
logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True)
raise

View File

@@ -1,39 +1,43 @@
# app/api/routes/auth.py
import logging
import os
from datetime import datetime, timezone
from datetime import UTC, datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
from app.core.auth import get_password_hash
from app.core.auth import (
TokenExpiredError,
TokenInvalidError,
decode_token,
get_password_hash,
)
from app.core.database import get_db
from app.core.exceptions import (
AuthenticationError as AuthError,
DatabaseError,
ErrorCode
ErrorCode,
)
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionCreate, LogoutRequest
from app.schemas.sessions import LogoutRequest, SessionCreate
from app.schemas.users import (
LoginRequest,
PasswordResetConfirm,
PasswordResetRequest,
RefreshTokenRequest,
Token,
UserCreate,
UserResponse,
Token,
LoginRequest,
RefreshTokenRequest,
PasswordResetRequest,
PasswordResetConfirm
)
from app.services.auth_service import AuthService, AuthenticationError
from app.services.auth_service import AuthenticationError, AuthService
from app.services.email_service import email_service
from app.utils.device import extract_device_info
from app.utils.security import create_password_reset_token, verify_password_reset_token
@@ -54,7 +58,7 @@ async def _create_login_session(
request: Request,
user: User,
tokens: Token,
login_type: str = "login"
login_type: str = "login",
) -> None:
"""
Create a session record for successful login.
@@ -81,8 +85,8 @@ async def _create_login_session(
device_id=device_info.device_id,
ip_address=device_info.ip_address,
user_agent=device_info.user_agent,
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
last_used_at=datetime.now(UTC),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
location_city=device_info.location_city,
location_country=device_info.location_country,
)
@@ -95,15 +99,20 @@ async def _create_login_session(
)
except Exception as session_err:
# Log but don't fail login if session creation fails
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
logger.error(
f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
@router.post(
"/register",
response_model=UserResponse,
status_code=status.HTTP_201_CREATED,
operation_id="register",
)
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
async def register_user(
request: Request,
user_data: UserCreate,
db: AsyncSession = Depends(get_db)
request: Request, user_data: UserCreate, db: AsyncSession = Depends(get_db)
) -> Any:
"""
Register a new user.
@@ -116,25 +125,23 @@ async def register_user(
return user
except AuthenticationError as e:
# SECURITY: Don't reveal if email exists - generic error message
logger.warning(f"Registration failed: {str(e)}")
logger.warning(f"Registration failed: {e!s}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Registration failed. Please check your information and try again."
detail="Registration failed. Please check your information and try again.",
)
except Exception as e:
logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True)
logger.error(f"Unexpected error during registration: {e!s}", exc_info=True)
raise DatabaseError(
message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR
error_code=ErrorCode.INTERNAL_ERROR,
)
@router.post("/login", response_model=Token, operation_id="login")
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def login(
request: Request,
login_data: LoginRequest,
db: AsyncSession = Depends(get_db)
request: Request, login_data: LoginRequest, db: AsyncSession = Depends(get_db)
) -> Any:
"""
Login with username and password.
@@ -146,14 +153,16 @@ async def login(
"""
try:
# Attempt to authenticate the user
user = await AuthService.authenticate_user(db, login_data.email, login_data.password)
user = await AuthService.authenticate_user(
db, login_data.email, login_data.password
)
# Explicitly check for None result and raise correct exception
if user is None:
logger.warning(f"Invalid login attempt for: {login_data.email}")
raise AuthError(
message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS
error_code=ErrorCode.INVALID_CREDENTIALS,
)
# User is authenticated, generate tokens
@@ -166,29 +175,26 @@ async def login(
except AuthenticationError as e:
# Handle specific authentication errors like inactive accounts
logger.warning(f"Authentication failed: {str(e)}")
raise AuthError(
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
)
logger.warning(f"Authentication failed: {e!s}")
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error during login: {str(e)}", exc_info=True)
logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
raise DatabaseError(
message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR
error_code=ErrorCode.INTERNAL_ERROR,
)
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
@router.post("/login/oauth", response_model=Token, operation_id="login_oauth")
@limiter.limit("10/minute")
async def login_oauth(
request: Request,
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db)
request: Request,
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
OAuth2-compatible login endpoint, used by the OpenAPI UI.
@@ -199,12 +205,14 @@ async def login_oauth(
Access and refresh tokens.
"""
try:
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
user = await AuthService.authenticate_user(
db, form_data.username, form_data.password
)
if user is None:
raise AuthError(
message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS
error_code=ErrorCode.INVALID_CREDENTIALS,
)
# Generate tokens
@@ -216,28 +224,25 @@ async def login_oauth(
# Return full token response with user data
return tokens
except AuthenticationError as e:
logger.warning(f"OAuth authentication failed: {str(e)}")
raise AuthError(
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
)
logger.warning(f"OAuth authentication failed: {e!s}")
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e:
logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True)
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
raise DatabaseError(
message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR
error_code=ErrorCode.INTERNAL_ERROR,
)
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
@limiter.limit("30/minute")
async def refresh_token(
request: Request,
refresh_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_db)
request: Request,
refresh_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Refresh access token using a refresh token.
@@ -249,13 +254,17 @@ async def refresh_token(
"""
try:
# Decode the refresh token to get the JTI
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
refresh_payload = decode_token(
refresh_data.refresh_token, verify_type="refresh"
)
# Check if session exists and is active
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
if not session:
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
logger.warning(
f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Session has been revoked. Please log in again.",
@@ -274,10 +283,12 @@ async def refresh_token(
db,
session=session,
new_jti=new_refresh_payload.jti,
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc)
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
)
except Exception as session_err:
logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True)
logger.error(
f"Failed to update session {session.id}: {session_err!s}", exc_info=True
)
# Continue anyway - tokens are already issued
return tokens
@@ -300,10 +311,10 @@ async def refresh_token(
# Re-raise HTTP exceptions (like session revoked)
raise
except Exception as e:
logger.error(f"Unexpected error during token refresh: {str(e)}")
logger.error(f"Unexpected error during token refresh: {e!s}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
detail="An unexpected error occurred. Please try again later.",
)
@@ -320,13 +331,13 @@ async def refresh_token(
**Rate Limit**: 3 requests/minute
""",
operation_id="request_password_reset"
operation_id="request_password_reset",
)
@limiter.limit("3/minute")
async def request_password_reset(
request: Request,
reset_request: PasswordResetRequest,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Request a password reset.
@@ -345,26 +356,26 @@ async def request_password_reset(
# Send password reset email
await email_service.send_password_reset_email(
to_email=user.email,
reset_token=reset_token,
user_name=user.first_name
to_email=user.email, reset_token=reset_token, user_name=user.first_name
)
logger.info(f"Password reset requested for {user.email}")
else:
# Log attempt but don't reveal if email exists
logger.warning(f"Password reset requested for non-existent or inactive email: {reset_request.email}")
logger.warning(
f"Password reset requested for non-existent or inactive email: {reset_request.email}"
)
# Always return success to prevent email enumeration
return MessageResponse(
success=True,
message="If your email is registered, you will receive a password reset link shortly"
message="If your email is registered, you will receive a password reset link shortly",
)
except Exception as e:
logger.error(f"Error processing password reset request: {str(e)}", exc_info=True)
logger.error(f"Error processing password reset request: {e!s}", exc_info=True)
# Still return success to prevent information leakage
return MessageResponse(
success=True,
message="If your email is registered, you will receive a password reset link shortly"
message="If your email is registered, you will receive a password reset link shortly",
)
@@ -378,13 +389,13 @@ async def request_password_reset(
**Rate Limit**: 5 requests/minute
""",
operation_id="confirm_password_reset"
operation_id="confirm_password_reset",
)
@limiter.limit("5/minute")
async def confirm_password_reset(
request: Request,
reset_confirm: PasswordResetConfirm,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Confirm password reset with token.
@@ -398,7 +409,7 @@ async def confirm_password_reset(
if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired password reset token"
detail="Invalid or expired password reset token",
)
# Look up user
@@ -406,14 +417,13 @@ async def confirm_password_reset(
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User account is inactive"
detail="User account is inactive",
)
# Update password
@@ -424,29 +434,33 @@ async def confirm_password_reset(
# SECURITY: Invalidate all existing sessions after password reset
# This prevents stolen sessions from being used after password change
from app.crud.session import session as session_crud
try:
deactivated_count = await session_crud.deactivate_all_user_sessions(
db,
user_id=str(user.id)
db, user_id=str(user.id)
)
logger.info(
f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions"
)
logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions")
except Exception as session_error:
# Log but don't fail password reset if session invalidation fails
logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}")
logger.error(
f"Failed to invalidate sessions after password reset: {session_error!s}"
)
return MessageResponse(
success=True,
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password."
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password.",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
logger.error(f"Error confirming password reset: {e!s}", exc_info=True)
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while resetting your password"
detail="An error occurred while resetting your password",
)
@@ -464,14 +478,14 @@ async def confirm_password_reset(
**Rate Limit**: 10 requests/minute
""",
operation_id="logout"
operation_id="logout",
)
@limiter.limit("10/minute")
async def logout(
request: Request,
logout_request: LogoutRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Logout from current device by deactivating the session.
@@ -487,15 +501,14 @@ async def logout(
try:
# Decode refresh token to get JTI
try:
refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh")
refresh_payload = decode_token(
logout_request.refresh_token, verify_type="refresh"
)
except (TokenExpiredError, TokenInvalidError) as e:
# Even if token is expired/invalid, try to deactivate session
logger.warning(f"Logout with invalid/expired token: {str(e)}")
logger.warning(f"Logout with invalid/expired token: {e!s}")
# Don't fail - return success anyway
return MessageResponse(
success=True,
message="Logged out successfully"
)
return MessageResponse(success=True, message="Logged out successfully")
# Find the session by JTI
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
@@ -509,7 +522,7 @@ async def logout(
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only logout your own sessions"
detail="You can only logout your own sessions",
)
# Deactivate the session
@@ -522,22 +535,20 @@ async def logout(
else:
# Session not found - maybe already deleted or never existed
# Return success anyway (idempotent)
logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})")
logger.info(
f"Logout requested for non-existent session (JTI: {refresh_payload.jti})"
)
return MessageResponse(
success=True,
message="Logged out successfully"
)
return MessageResponse(success=True, message="Logged out successfully")
except HTTPException:
raise
except Exception as e:
logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True)
# Don't expose error details
return MessageResponse(
success=True,
message="Logged out successfully"
logger.error(
f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
)
# Don't expose error details
return MessageResponse(success=True, message="Logged out successfully")
@router.post(
@@ -553,13 +564,13 @@ async def logout(
**Rate Limit**: 5 requests/minute
""",
operation_id="logout_all"
operation_id="logout_all",
)
@limiter.limit("5/minute")
async def logout_all(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Logout from all devices by deactivating all user sessions.
@@ -573,19 +584,25 @@ async def logout_all(
"""
try:
# Deactivate all sessions for this user
count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
count = await session_crud.deactivate_all_user_sessions(
db, user_id=str(current_user.id)
)
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
logger.info(
f"User {current_user.id} logged out from all devices ({count} sessions)"
)
return MessageResponse(
success=True,
message=f"Successfully logged out from all devices ({count} sessions terminated)"
message=f"Successfully logged out from all devices ({count} sessions terminated)",
)
except Exception as e:
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
logger.error(
f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
)
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while logging out"
detail="An error occurred while logging out",
)

View File

@@ -4,8 +4,9 @@ Organization endpoints for regular users.
These endpoints allow users to view and manage organizations they belong to.
"""
import logging
from typing import Any, List
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query
@@ -14,18 +15,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership
from app.core.database import get_db
from app.core.exceptions import NotFoundError, ErrorCode
from app.core.exceptions import ErrorCode, NotFoundError
from app.crud.organization import organization as organization_crud
from app.models.user import User
from app.schemas.common import (
PaginationParams,
PaginatedResponse,
create_pagination_meta
PaginationParams,
create_pagination_meta,
)
from app.schemas.organizations import (
OrganizationResponse,
OrganizationMemberResponse,
OrganizationUpdate
OrganizationResponse,
OrganizationUpdate,
)
logger = logging.getLogger(__name__)
@@ -35,15 +36,15 @@ router = APIRouter()
@router.get(
"/me",
response_model=List[OrganizationResponse],
response_model=list[OrganizationResponse],
summary="Get My Organizations",
description="Get all organizations the current user belongs to",
operation_id="get_my_organizations"
operation_id="get_my_organizations",
)
async def get_my_organizations(
is_active: bool = Query(True, description="Filter by active membership"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get all organizations the current user belongs to.
@@ -54,15 +55,13 @@ async def get_my_organizations(
try:
# Get all org data in single query with JOIN and subquery
orgs_data = await organization_crud.get_user_organizations_with_details(
db,
user_id=current_user.id,
is_active=is_active
db, user_id=current_user.id, is_active=is_active
)
# Transform to response objects
orgs_with_data = []
for item in orgs_data:
org = item['organization']
org = item["organization"]
org_dict = {
"id": org.id,
"name": org.name,
@@ -72,14 +71,14 @@ async def get_my_organizations(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": item['member_count']
"member_count": item["member_count"],
}
orgs_with_data.append(OrganizationResponse(**org_dict))
return orgs_with_data
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}", exc_info=True)
logger.error(f"Error getting user organizations: {e!s}", exc_info=True)
raise
@@ -88,12 +87,12 @@ async def get_my_organizations(
response_model=OrganizationResponse,
summary="Get Organization Details",
description="Get details of an organization the user belongs to",
operation_id="get_organization"
operation_id="get_organization",
)
async def get_organization(
organization_id: UUID,
current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get details of a specific organization.
@@ -105,7 +104,7 @@ async def get_organization(
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError(
detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
org_dict = {
@@ -117,14 +116,16 @@ async def get_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
"member_count": await organization_crud.get_member_count(
db, organization_id=org.id
),
}
return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e:
logger.error(f"Error getting organization: {str(e)}", exc_info=True)
logger.error(f"Error getting organization: {e!s}", exc_info=True)
raise
@@ -133,14 +134,14 @@ async def get_organization(
response_model=PaginatedResponse[OrganizationMemberResponse],
summary="Get Organization Members",
description="Get all members of an organization (members can view)",
operation_id="get_organization_members"
operation_id="get_organization_members",
)
async def get_organization_members(
organization_id: UUID,
pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"),
current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get all members of an organization.
@@ -153,7 +154,7 @@ async def get_organization_members(
organization_id=organization_id,
skip=pagination.offset,
limit=pagination.limit,
is_active=is_active
is_active=is_active,
)
member_responses = [OrganizationMemberResponse(**member) for member in members]
@@ -162,13 +163,13 @@ async def get_organization_members(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(member_responses)
items_count=len(member_responses),
)
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}", exc_info=True)
logger.error(f"Error getting organization members: {e!s}", exc_info=True)
raise
@@ -177,13 +178,13 @@ async def get_organization_members(
response_model=OrganizationResponse,
summary="Update Organization",
description="Update organization details (admin/owner only)",
operation_id="update_organization"
operation_id="update_organization",
)
async def update_organization(
organization_id: UUID,
org_in: OrganizationUpdate,
current_user: User = Depends(require_org_admin),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update organization details.
@@ -195,11 +196,13 @@ async def update_organization(
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError(
detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info(f"User {current_user.email} updated organization {updated_org.name}")
logger.info(
f"User {current_user.email} updated organization {updated_org.name}"
)
org_dict = {
"id": updated_org.id,
@@ -210,12 +213,14 @@ async def update_organization(
"settings": updated_org.settings,
"created_at": updated_org.created_at,
"updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
"member_count": await organization_crud.get_member_count(
db, organization_id=updated_org.id
),
}
return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e:
logger.error(f"Error updating organization: {str(e)}", exc_info=True)
logger.error(f"Error updating organization: {e!s}", exc_info=True)
raise

View File

@@ -3,11 +3,12 @@ Session management endpoints.
Allows users to view and manage their active sessions across devices.
"""
import logging
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, Depends, HTTPException, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
@@ -15,11 +16,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.auth import decode_token
from app.core.database import get_db
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
from app.crud.session import session as session_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionResponse, SessionListResponse
from app.schemas.sessions import SessionListResponse, SessionResponse
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -39,13 +40,13 @@ limiter = Limiter(key_func=get_remote_address)
**Rate Limit**: 30 requests/minute
""",
operation_id="list_my_sessions"
operation_id="list_my_sessions",
)
@limiter.limit("30/minute")
async def list_my_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List all active sessions for the current user.
@@ -60,18 +61,15 @@ async def list_my_sessions(
try:
# Get all active sessions for user
sessions = await session_crud.get_user_sessions(
db,
user_id=str(current_user.id),
active_only=True
db, user_id=str(current_user.id), active_only=True
)
# Try to identify current session from Authorization header
current_session_jti = None
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
try:
access_token = auth_header.split(" ")[1]
token_payload = decode_token(access_token)
decode_token(access_token)
# Note: Access tokens don't have JTI by default, but we can try
# For now, we'll mark current based on most recent activity
except Exception:
@@ -90,22 +88,27 @@ async def list_my_sessions(
last_used_at=s.last_used_at,
created_at=s.created_at,
expires_at=s.expires_at,
is_current=(s == sessions[0] if sessions else False) # Most recent = current
is_current=(
s == sessions[0] if sessions else False
), # Most recent = current
)
session_responses.append(session_response)
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
logger.info(
f"User {current_user.id} listed {len(session_responses)} active sessions"
)
return SessionListResponse(
sessions=session_responses,
total=len(session_responses)
sessions=session_responses, total=len(session_responses)
)
except Exception as e:
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
logger.error(
f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve sessions"
detail="Failed to retrieve sessions",
)
@@ -122,14 +125,14 @@ async def list_my_sessions(
**Rate Limit**: 10 requests/minute
""",
operation_id="revoke_session"
operation_id="revoke_session",
)
@limiter.limit("10/minute")
async def revoke_session(
request: Request,
session_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Revoke a specific session by ID.
@@ -149,7 +152,7 @@ async def revoke_session(
if not session:
raise NotFoundError(
message=f"Session {session_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
# Verify session belongs to current user
@@ -160,7 +163,7 @@ async def revoke_session(
)
raise AuthorizationError(
message="You can only revoke your own sessions",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Deactivate the session
@@ -173,16 +176,16 @@ async def revoke_session(
return MessageResponse(
success=True,
message=f"Session revoked: {session.device_name or 'Unknown device'}"
message=f"Session revoked: {session.device_name or 'Unknown device'}",
)
except (NotFoundError, AuthorizationError):
raise
except Exception as e:
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke session"
detail="Failed to revoke session",
)
@@ -198,13 +201,13 @@ async def revoke_session(
**Rate Limit**: 5 requests/minute
""",
operation_id="cleanup_expired_sessions"
operation_id="cleanup_expired_sessions",
)
@limiter.limit("5/minute")
async def cleanup_expired_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Cleanup expired sessions for the current user.
@@ -219,21 +222,24 @@ async def cleanup_expired_sessions(
try:
# Use optimized bulk DELETE instead of N individual deletes
deleted_count = await session_crud.cleanup_expired_for_user(
db,
user_id=str(current_user.id)
db, user_id=str(current_user.id)
)
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
logger.info(
f"User {current_user.id} cleaned up {deleted_count} expired sessions"
)
return MessageResponse(
success=True,
message=f"Cleaned up {deleted_count} expired sessions"
success=True, message=f"Cleaned up {deleted_count} expired sessions"
)
except Exception as e:
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
logger.error(
f"Error cleaning up sessions for user {current_user.id}: {e!s}",
exc_info=True,
)
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cleanup sessions"
detail="Failed to cleanup sessions",
)

View File

@@ -1,33 +1,30 @@
"""
User management endpoints for CRUD operations.
"""
import logging
from typing import Any, Optional
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query, status, Request
from fastapi import APIRouter, Depends, Query, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user, get_current_superuser
from app.api.dependencies.auth import get_current_superuser, get_current_user
from app.core.database import get_db
from app.core.exceptions import (
NotFoundError,
AuthorizationError,
ErrorCode
)
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.common import (
PaginationParams,
PaginatedResponse,
MessageResponse,
PaginatedResponse,
PaginationParams,
SortParams,
create_pagination_meta
create_pagination_meta,
)
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
from app.services.auth_service import AuthService, AuthenticationError
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
from app.services.auth_service import AuthenticationError, AuthService
logger = logging.getLogger(__name__)
@@ -50,15 +47,15 @@ limiter = Limiter(key_func=get_remote_address)
**Rate Limit**: 60 requests/minute
""",
operation_id="list_users"
operation_id="list_users",
)
async def list_users(
pagination: PaginationParams = Depends(),
sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
is_active: bool | None = Query(None, description="Filter by active status"),
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List all users with pagination, filtering, and sorting.
@@ -80,7 +77,7 @@ async def list_users(
limit=pagination.limit,
sort_by=sort.sort_by,
sort_order=sort.sort_order.value if sort.sort_order else "asc",
filters=filters if filters else None
filters=filters if filters else None,
)
# Create pagination metadata
@@ -88,15 +85,12 @@ async def list_users(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(users)
items_count=len(users),
)
return PaginatedResponse(
data=users,
pagination=pagination_meta
)
return PaginatedResponse(data=users, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing users: {str(e)}", exc_info=True)
logger.error(f"Error listing users: {e!s}", exc_info=True)
raise
@@ -111,11 +105,9 @@ async def list_users(
**Rate Limit**: 60 requests/minute
""",
operation_id="get_current_user_profile"
operation_id="get_current_user_profile",
)
def get_current_user_profile(
current_user: User = Depends(get_current_user)
) -> Any:
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
"""Get current user's profile."""
return current_user
@@ -133,12 +125,12 @@ def get_current_user_profile(
**Rate Limit**: 30 requests/minute
""",
operation_id="update_current_user"
operation_id="update_current_user",
)
async def update_current_user(
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update current user's profile.
@@ -147,17 +139,17 @@ async def update_current_user(
"""
try:
updated_user = await user_crud.update(
db,
db_obj=current_user,
obj_in=user_update
db, db_obj=current_user, obj_in=user_update
)
logger.info(f"User {current_user.id} updated their profile")
return updated_user
except ValueError as e:
logger.error(f"Error updating user {current_user.id}: {str(e)}")
logger.error(f"Error updating user {current_user.id}: {e!s}")
raise
except Exception as e:
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
)
raise
@@ -175,12 +167,12 @@ async def update_current_user(
**Rate Limit**: 60 requests/minute
""",
operation_id="get_user_by_id"
operation_id="get_user_by_id",
)
async def get_user_by_id(
user_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get user by ID.
@@ -194,7 +186,7 @@ async def get_user_by_id(
)
raise AuthorizationError(
message="Not enough permissions to view this user",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Get user
@@ -202,7 +194,7 @@ async def get_user_by_id(
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND,
)
return user
@@ -222,13 +214,13 @@ async def get_user_by_id(
**Rate Limit**: 30 requests/minute
""",
operation_id="update_user"
operation_id="update_user",
)
async def update_user(
user_id: UUID,
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update user by ID.
@@ -245,7 +237,7 @@ async def update_user(
)
raise AuthorizationError(
message="Not enough permissions to update this user",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Get user
@@ -253,7 +245,7 @@ async def update_user(
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND,
)
try:
@@ -261,10 +253,10 @@ async def update_user(
logger.info(f"User {user_id} updated by {current_user.id}")
return updated_user
except ValueError as e:
logger.error(f"Error updating user {user_id}: {str(e)}")
logger.error(f"Error updating user {user_id}: {e!s}")
raise
except Exception as e:
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True)
logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True)
raise
@@ -281,14 +273,14 @@ async def update_user(
**Rate Limit**: 5 requests/minute
""",
operation_id="change_current_user_password"
operation_id="change_current_user_password",
)
@limiter.limit("5/minute")
async def change_current_user_password(
request: Request,
password_change: PasswordChange,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Change current user's password.
@@ -300,23 +292,23 @@ async def change_current_user_password(
db=db,
user_id=current_user.id,
current_password=password_change.current_password,
new_password=password_change.new_password
new_password=password_change.new_password,
)
if success:
logger.info(f"User {current_user.id} changed their password")
return MessageResponse(
success=True,
message="Password changed successfully"
success=True, message="Password changed successfully"
)
except AuthenticationError as e:
logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
logger.warning(
f"Failed password change attempt for user {current_user.id}: {e!s}"
)
raise AuthorizationError(
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
)
except Exception as e:
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
logger.error(f"Error changing password for user {current_user.id}: {e!s}")
raise
@@ -335,12 +327,12 @@ async def change_current_user_password(
**Note**: This performs a hard delete. Consider implementing soft deletes for production.
""",
operation_id="delete_user"
operation_id="delete_user",
)
async def delete_user(
user_id: UUID,
current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Delete user by ID (superuser only).
@@ -351,7 +343,7 @@ async def delete_user(
if str(user_id) == str(current_user.id):
raise AuthorizationError(
message="Cannot delete your own account",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Get user
@@ -359,7 +351,7 @@ async def delete_user(
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND,
)
try:
@@ -367,12 +359,11 @@ async def delete_user(
await user_crud.soft_delete(db, id=str(user_id))
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
return MessageResponse(
success=True,
message=f"User {user_id} deleted successfully"
success=True, message=f"User {user_id} deleted successfully"
)
except ValueError as e:
logger.error(f"Error deleting user {user_id}: {str(e)}")
logger.error(f"Error deleting user {user_id}: {e!s}")
raise
except Exception as e:
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True)
logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True)
raise

View File

@@ -1,39 +1,39 @@
import logging
logging.getLogger('passlib').setLevel(logging.ERROR)
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union
import uuid
logging.getLogger("passlib").setLevel(logging.ERROR)
import asyncio
import uuid
from datetime import UTC, datetime, timedelta
from functools import partial
from typing import Any
from jose import jwt, JWTError
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import ValidationError
from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth
class AuthError(Exception):
"""Base authentication error"""
pass
class TokenExpiredError(AuthError):
"""Token has expired"""
pass
class TokenInvalidError(AuthError):
"""Token is invalid"""
pass
class TokenMissingClaimError(AuthError):
"""Token is missing a required claim"""
pass
def verify_password(plain_password: str, hashed_password: str) -> bool:
@@ -62,8 +62,7 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
partial(pwd_context.verify, plain_password, hashed_password)
None, partial(pwd_context.verify, plain_password, hashed_password)
)
@@ -82,17 +81,13 @@ async def get_password_hash_async(password: str) -> str:
Hashed password string
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
pwd_context.hash,
password
)
return await loop.run_in_executor(None, pwd_context.hash, password)
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
claims: Optional[Dict[str, Any]] = None
subject: str | Any,
expires_delta: timedelta | None = None,
claims: dict[str, Any] | None = None,
) -> str:
"""
Create a JWT access token.
@@ -106,17 +101,19 @@ def create_access_token(
Encoded JWT token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(UTC) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
# Base token data
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(tz=timezone.utc),
"iat": datetime.now(tz=UTC),
"jti": str(uuid.uuid4()),
"type": "access"
"type": "access",
}
# Add custom claims
@@ -125,17 +122,14 @@ def create_access_token(
# Create the JWT
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None
subject: str | Any, expires_delta: timedelta | None = None
) -> str:
"""
Create a JWT refresh token.
@@ -148,28 +142,26 @@ def create_refresh_token(
Encoded JWT refresh token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
expire = datetime.now(UTC) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(timezone.utc),
"iat": datetime.now(UTC),
"jti": str(uuid.uuid4()),
"type": "refresh"
"type": "refresh",
}
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
"""
Decode and verify a JWT token.
@@ -195,8 +187,8 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
"verify_signature": True,
"verify_exp": True,
"verify_iat": True,
"require": ["exp", "sub", "iat"]
}
"require": ["exp", "sub", "iat"],
},
)
# SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks
@@ -250,4 +242,4 @@ def get_token_data(token: str) -> TokenData:
user_id = payload.sub
is_superuser = payload.is_superuser or False
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)

View File

@@ -1,5 +1,4 @@
import logging
from typing import Optional, List
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
@@ -13,7 +12,7 @@ class Settings(BaseSettings):
# Environment (must be before SECRET_KEY for validation)
ENVIRONMENT: str = Field(
default="development",
description="Environment: development, staging, or production"
description="Environment: development, staging, or production",
)
# Security: Content Security Policy
@@ -21,8 +20,7 @@ class Settings(BaseSettings):
# Set to True for strict CSP (blocks most external resources)
# Set to "relaxed" for modern frontend development
CSP_MODE: str = Field(
default="relaxed",
description="CSP mode: 'strict', 'relaxed', or 'disabled'"
default="relaxed", description="CSP mode: 'strict', 'relaxed', or 'disabled'"
)
# Database configuration
@@ -31,7 +29,7 @@ class Settings(BaseSettings):
POSTGRES_HOST: str = "localhost"
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "app"
DATABASE_URL: Optional[str] = None
DATABASE_URL: str | None = None
db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection
@@ -59,38 +57,36 @@ class Settings(BaseSettings):
SECRET_KEY: str = Field(
default="dev_only_insecure_key_change_in_production_32chars_min",
min_length=32,
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'",
)
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard)
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days
# CORS configuration
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
BACKEND_CORS_ORIGINS: list[str] = ["http://localhost:3000"]
# Frontend URL for email links
FRONTEND_URL: str = Field(
default="http://localhost:3000",
description="Frontend application URL for email links"
description="Frontend application URL for email links",
)
# Admin user
FIRST_SUPERUSER_EMAIL: Optional[str] = Field(
default=None,
description="Email for first superuser account"
FIRST_SUPERUSER_EMAIL: str | None = Field(
default=None, description="Email for first superuser account"
)
FIRST_SUPERUSER_PASSWORD: Optional[str] = Field(
default=None,
description="Password for first superuser (min 12 characters)"
FIRST_SUPERUSER_PASSWORD: str | None = Field(
default=None, description="Password for first superuser (min 12 characters)"
)
@field_validator('SECRET_KEY')
@field_validator("SECRET_KEY")
@classmethod
def validate_secret_key(cls, v: str, info) -> str:
"""Validate SECRET_KEY is secure, especially in production."""
# Get environment from values if available
values_data = info.data if info.data else {}
env = values_data.get('ENVIRONMENT', 'development')
env = values_data.get("ENVIRONMENT", "development")
if v.startswith("your_secret_key_here"):
if env == "production":
@@ -106,13 +102,15 @@ class Settings(BaseSettings):
)
if len(v) < 32:
raise ValueError("SECRET_KEY must be at least 32 characters long for security")
raise ValueError(
"SECRET_KEY must be at least 32 characters long for security"
)
return v
@field_validator('FIRST_SUPERUSER_PASSWORD')
@field_validator("FIRST_SUPERUSER_PASSWORD")
@classmethod
def validate_superuser_password(cls, v: Optional[str]) -> Optional[str]:
def validate_superuser_password(cls, v: str | None) -> str | None:
"""Validate superuser password strength."""
if v is None:
return v
@@ -121,7 +119,13 @@ class Settings(BaseSettings):
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
# Check for common weak passwords
weak_passwords = {'admin123', 'Admin123', 'password123', 'Password123', '123456789012'}
weak_passwords = {
"admin123",
"Admin123",
"password123",
"Password123",
"123456789012",
}
if v in weak_passwords:
raise ValueError(
"FIRST_SUPERUSER_PASSWORD is too weak. "
@@ -144,8 +148,8 @@ class Settings(BaseSettings):
"env_file": "../.env",
"env_file_encoding": "utf-8",
"case_sensitive": True,
"extra": "ignore" # Ignore extra fields from .env (e.g., frontend-specific vars)
"extra": "ignore", # Ignore extra fields from .env (e.g., frontend-specific vars)
}
settings = Settings()
settings = Settings()

View File

@@ -5,17 +5,18 @@ Database configuration using SQLAlchemy 2.0 and asyncpg.
This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints.
"""
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
create_async_engine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import DeclarativeBase
@@ -27,12 +28,12 @@ logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
@compiles(JSONB, "sqlite")
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
@compiles(UUID, "sqlite")
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
@@ -40,7 +41,6 @@ def compile_uuid_sqlite(type_, compiler, **kw):
# Declarative base for models (SQLAlchemy 2.0 style)
class Base(DeclarativeBase):
"""Base class for all database models."""
pass
def get_async_database_url(url: str) -> str:
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
logger.debug("Async transaction committed successfully")
except Exception as e:
await session.rollback()
logger.error(f"Async transaction failed, rolling back: {str(e)}")
logger.error(f"Async transaction failed, rolling back: {e!s}")
raise
finally:
await session.close()
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
await db.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Async database health check failed: {str(e)}")
logger.error(f"Async database health check failed: {e!s}")
return False

View File

@@ -1,8 +1,8 @@
"""
Custom exceptions and global exception handlers for the API.
"""
import logging
from typing import Optional, Union
from fastapi import HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
@@ -27,17 +27,13 @@ class APIException(HTTPException):
status_code: int,
error_code: ErrorCode,
message: str,
field: Optional[str] = None,
headers: Optional[dict] = None
field: str | None = None,
headers: dict | None = None,
):
self.error_code = error_code
self.field = field
self.message = message
super().__init__(
status_code=status_code,
detail=message,
headers=headers
)
super().__init__(status_code=status_code, detail=message, headers=headers)
class AuthenticationError(APIException):
@@ -47,14 +43,14 @@ class AuthenticationError(APIException):
self,
message: str = "Authentication failed",
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
field: Optional[str] = None
field: str | None = None,
):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
error_code=error_code,
message=message,
field=field,
headers={"WWW-Authenticate": "Bearer"}
headers={"WWW-Authenticate": "Bearer"},
)
@@ -64,12 +60,12 @@ class AuthorizationError(APIException):
def __init__(
self,
message: str = "Insufficient permissions",
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS,
):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
error_code=error_code,
message=message
message=message,
)
@@ -79,12 +75,12 @@ class NotFoundError(APIException):
def __init__(
self,
message: str = "Resource not found",
error_code: ErrorCode = ErrorCode.NOT_FOUND
error_code: ErrorCode = ErrorCode.NOT_FOUND,
):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
error_code=error_code,
message=message
message=message,
)
@@ -95,13 +91,13 @@ class DuplicateError(APIException):
self,
message: str = "Resource already exists",
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
field: Optional[str] = None
field: str | None = None,
):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
error_code=error_code,
message=message,
field=field
field=field,
)
@@ -112,13 +108,13 @@ class ValidationException(APIException):
self,
message: str = "Validation error",
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
field: Optional[str] = None
field: str | None = None,
):
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
error_code=error_code,
message=message,
field=field
field=field,
)
@@ -128,12 +124,12 @@ class DatabaseError(APIException):
def __init__(
self,
message: str = "Database operation failed",
error_code: ErrorCode = ErrorCode.DATABASE_ERROR
error_code: ErrorCode = ErrorCode.DATABASE_ERROR,
):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
error_code=error_code,
message=message
message=message,
)
@@ -152,23 +148,18 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
)
error_response = ErrorResponse(
errors=[ErrorDetail(
code=exc.error_code,
message=exc.message,
field=exc.field
)]
errors=[ErrorDetail(code=exc.error_code, message=exc.message, field=exc.field)]
)
return JSONResponse(
status_code=exc.status_code,
content=error_response.model_dump(),
headers=exc.headers
headers=exc.headers,
)
async def validation_exception_handler(
request: Request,
exc: Union[RequestValidationError, ValidationError]
request: Request, exc: RequestValidationError | ValidationError
) -> JSONResponse:
"""
Handler for Pydantic validation errors.
@@ -189,22 +180,19 @@ async def validation_exception_handler(
# Skip 'body' or 'query' prefix in location
field = ".".join(str(x) for x in error["loc"][1:])
errors.append(ErrorDetail(
code=ErrorCode.VALIDATION_ERROR,
message=error["msg"],
field=field
))
errors.append(
ErrorDetail(
code=ErrorCode.VALIDATION_ERROR, message=error["msg"], field=field
)
)
logger.warning(
f"Validation error: {len(errors)} errors "
f"(path: {request.url.path})"
)
logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})")
error_response = ErrorResponse(errors=errors)
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=error_response.model_dump()
content=error_response.model_dump(),
)
@@ -226,26 +214,21 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
}
error_code = status_code_to_error_code.get(
exc.status_code,
ErrorCode.INTERNAL_ERROR
exc.status_code, ErrorCode.INTERNAL_ERROR
)
logger.warning(
f"HTTP exception: {exc.status_code} - {exc.detail} "
f"(path: {request.url.path})"
f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})"
)
error_response = ErrorResponse(
errors=[ErrorDetail(
code=error_code,
message=str(exc.detail)
)]
errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
)
return JSONResponse(
status_code=exc.status_code,
content=error_response.model_dump(),
headers=exc.headers
headers=exc.headers,
)
@@ -257,26 +240,24 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
leaking sensitive information in production.
"""
logger.error(
f"Unhandled exception: {type(exc).__name__} - {str(exc)} "
f"Unhandled exception: {type(exc).__name__} - {exc!s} "
f"(path: {request.url.path})",
exc_info=True
exc_info=True,
)
# In production, don't expose internal error details
from app.core.config import settings
if settings.ENVIRONMENT == "production":
message = "An internal error occurred. Please try again later."
else:
message = f"{type(exc).__name__}: {str(exc)}"
message = f"{type(exc).__name__}: {exc!s}"
error_response = ErrorResponse(
errors=[ErrorDetail(
code=ErrorCode.INTERNAL_ERROR,
message=message
)]
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=error_response.model_dump()
content=error_response.model_dump(),
)

View File

@@ -3,4 +3,4 @@ from .organization import organization
from .session import session as session_crud
from .user import user
__all__ = ["user", "session_crud", "organization"]
__all__ = ["organization", "session_crud", "user"]

View File

@@ -4,14 +4,16 @@ Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
import logging
import uuid
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from datetime import UTC
from typing import Any, TypeVar
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
@@ -24,10 +26,14 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
class CRUDBase[
ModelType: Base,
CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel,
]:
"""Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]):
def __init__(self, model: type[ModelType]):
"""
CRUD object with default async methods to Create, Read, Update, Delete.
@@ -37,11 +43,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
self.model = model
async def get(
self,
db: AsyncSession,
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
self, db: AsyncSession, id: str, options: list[Load] | None = None
) -> ModelType | None:
"""
Get a single record by ID with UUID validation and optional eager loading.
@@ -66,7 +69,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
logger.warning(f"Invalid UUID format: {id} - {e!s}")
return None
try:
@@ -80,7 +83,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
raise
async def get_multi(
@@ -89,8 +92,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
*,
skip: int = 0,
limit: int = 100,
options: Optional[List[Load]] = None
) -> List[ModelType]:
options: list[Load] | None = None,
) -> list[ModelType]:
"""
Get multiple records with pagination validation and optional eager loading.
@@ -122,10 +125,14 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
logger.error(
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
)
raise
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: # pragma: no cover
async def create(
self, db: AsyncSession, *, obj_in: CreateSchemaType
) -> ModelType: # pragma: no cover
"""Create a new record with error handling.
NOTE: This method is defensive code that's never called in practice.
@@ -142,19 +149,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: # pragma: no cover
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {e!s}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
)
raise
async def update(
@@ -162,7 +175,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
obj_in: UpdateSchemaType | dict[str, Any],
) -> ModelType:
"""Update a record with error handling."""
try:
@@ -182,22 +195,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {e!s}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
)
raise
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try:
@@ -206,7 +225,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
return None
try:
@@ -216,7 +235,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
logger.warning(
f"{self.model.__name__} with id {id} not found for deletion"
)
return None
await db.delete(obj)
@@ -224,12 +245,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
raise ValueError(
f"Cannot delete {self.model.__name__}: referenced by other records"
)
except Exception as e:
await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
logger.error(
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise
async def get_multi_with_total(
@@ -238,10 +264,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_by: str | None = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]:
filters: dict[str, Any] | None = None,
) -> tuple[list[ModelType], int]:
"""
Get multiple records with total count, filtering, and sorting.
@@ -269,7 +295,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
if hasattr(self.model, "deleted_at"):
query = query.where(self.model.deleted_at.is_(None))
# Apply filters
@@ -298,7 +324,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
logger.error(
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
)
raise
async def count(self, db: AsyncSession) -> int:
@@ -307,7 +335,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
logger.error(f"Error counting {self.model.__name__} records: {e!s}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
@@ -315,13 +343,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime, timezone
from datetime import datetime
# Validate UUID format and convert to UUID object if string
try:
@@ -330,7 +358,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
return None
try:
@@ -340,26 +368,33 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
logger.warning(
f"{self.model.__name__} with id {id} not found for soft deletion"
)
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
if not hasattr(self.model, "deleted_at"):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
raise ValueError(
f"{self.model.__name__} does not have a deleted_at column"
)
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
obj.deleted_at = datetime.now(UTC)
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
logger.error(
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
@@ -372,25 +407,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
if hasattr(self.model, "deleted_at"):
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
raise ValueError(
f"{self.model.__name__} does not have a deleted_at column"
)
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
logger.warning(
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
)
return None
# Clear deleted_at timestamp
@@ -401,5 +439,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
logger.error(
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise

View File

@@ -1,17 +1,18 @@
# app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
import logging
from typing import Optional, List, Dict, Any
from typing import Any
from uuid import UUID
from sqlalchemy import func, or_, and_, select, case
from sqlalchemy import and_, case, func, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.models.user_organization import OrganizationRole, UserOrganization
from app.schemas.organizations import (
OrganizationCreate,
OrganizationUpdate,
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
"""Async CRUD operations for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
"""Get organization by slug."""
try:
result = await db.execute(
@@ -31,10 +32,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
logger.error(f"Error getting organization by slug {slug}: {e!s}")
raise
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
async def create(
self, db: AsyncSession, *, obj_in: OrganizationCreate
) -> Organization:
"""Create a new organization with error handling."""
try:
db_obj = Organization(
@@ -42,7 +45,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
slug=obj_in.slug,
description=obj_in.description,
is_active=obj_in.is_active,
settings=obj_in.settings or {}
settings=obj_in.settings or {},
)
db.add(db_obj)
await db.commit()
@@ -50,15 +53,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
raise ValueError(
f"Organization with slug '{obj_in.slug}' already exists"
)
logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error creating organization: {e!s}", exc_info=True
)
raise
async def get_multi_with_filters(
@@ -67,11 +74,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None,
is_active: bool | None = None,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc"
) -> tuple[List[Organization], int]:
sort_order: str = "desc",
) -> tuple[list[Organization], int]:
"""
Get multiple organizations with filtering, searching, and sorting.
@@ -89,7 +96,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
Organization.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
@@ -112,7 +119,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return organizations, total
except Exception as e:
logger.error(f"Error getting organizations with filters: {str(e)}")
logger.error(f"Error getting organizations with filters: {e!s}")
raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
@@ -122,13 +129,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(func.count(UserOrganization.user_id)).where(
and_(
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
UserOrganization.is_active,
)
)
)
return result.scalar_one() or 0
except Exception as e:
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
logger.error(
f"Error getting member count for organization {organization_id}: {e!s}"
)
raise
async def get_multi_with_member_counts(
@@ -137,9 +146,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> tuple[List[Dict[str, Any]], int]:
is_active: bool | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
@@ -156,13 +165,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
func.count(
func.distinct(
case(
(UserOrganization.is_active == True, UserOrganization.user_id),
else_=None
(
UserOrganization.is_active,
UserOrganization.user_id,
),
else_=None,
)
)
).label('member_count')
).label("member_count"),
)
.outerjoin(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.id)
)
@@ -174,7 +189,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
Organization.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
@@ -189,24 +204,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
query = (
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
)
result = await db.execute(query)
rows = result.all()
# Convert to list of dicts
orgs_with_counts = [
{
'organization': org,
'member_count': member_count
}
{"organization": org, "member_count": member_count}
for org, member_count in rows
]
return orgs_with_counts, total
except Exception as e:
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
logger.error(
f"Error getting organizations with member counts: {e!s}", exc_info=True
)
raise
async def add_user(
@@ -216,7 +232,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER,
custom_permissions: Optional[str] = None
custom_permissions: str | None = None,
) -> UserOrganization:
"""Add a user to an organization with a specific role."""
try:
@@ -225,7 +241,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
UserOrganization.organization_id == organization_id,
)
)
)
@@ -249,7 +265,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id=organization_id,
role=role,
is_active=True,
custom_permissions=custom_permissions
custom_permissions=custom_permissions,
)
db.add(user_org)
await db.commit()
@@ -257,19 +273,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org
except IntegrityError as e:
await db.rollback()
logger.error(f"Integrity error adding user to organization: {str(e)}")
logger.error(f"Integrity error adding user to organization: {e!s}")
raise ValueError("Failed to add user to organization")
except Exception as e:
await db.rollback()
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
raise
async def remove_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID
self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
) -> bool:
"""Remove a user from an organization (soft delete)."""
try:
@@ -277,7 +289,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
UserOrganization.organization_id == organization_id,
)
)
)
@@ -291,7 +303,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return True
except Exception as e:
await db.rollback()
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
raise
async def update_user_role(
@@ -301,15 +313,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
user_id: UUID,
role: OrganizationRole,
custom_permissions: Optional[str] = None
) -> Optional[UserOrganization]:
custom_permissions: str | None = None,
) -> UserOrganization | None:
"""Update a user's role in an organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
UserOrganization.organization_id == organization_id,
)
)
)
@@ -326,7 +338,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org
except Exception as e:
await db.rollback()
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
logger.error(f"Error updating user role: {e!s}", exc_info=True)
raise
async def get_organization_members(
@@ -336,8 +348,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool = True
) -> tuple[List[Dict[str, Any]], int]:
is_active: bool = True,
) -> tuple[list[dict[str, Any]], int]:
"""
Get members of an organization with user details.
@@ -359,46 +371,55 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
.where(UserOrganization.is_active == is_active if is_active is not None else True)
.where(
UserOrganization.is_active == is_active
if is_active is not None
else True
)
.alias()
)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply ordering and pagination
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
query = (
query.order_by(UserOrganization.created_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query)
results = result.all()
members = []
for user_org, user in results:
members.append({
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at
})
members.append(
{
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at,
}
)
return members, total
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}")
logger.error(f"Error getting organization members: {e!s}")
raise
async def get_user_organizations(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Organization]:
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
) -> list[Organization]:
"""Get all organizations a user belongs to."""
try:
query = (
select(Organization)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.where(UserOrganization.user_id == user_id)
)
@@ -408,16 +429,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}")
logger.error(f"Error getting user organizations: {e!s}")
raise
async def get_user_organizations_with_details(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Dict[str, Any]]:
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
) -> list[dict[str, Any]]:
"""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
@@ -430,9 +447,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
member_count_subq = (
select(
UserOrganization.organization_id,
func.count(UserOrganization.user_id).label('member_count')
func.count(UserOrganization.user_id).label("member_count"),
)
.where(UserOrganization.is_active == True)
.where(UserOrganization.is_active)
.group_by(UserOrganization.organization_id)
.subquery()
)
@@ -442,10 +459,18 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(
Organization,
UserOrganization.role,
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
func.coalesce(member_count_subq.c.member_count, 0).label(
"member_count"
),
)
.join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.outerjoin(
member_count_subq,
Organization.id == member_count_subq.c.organization_id,
)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
.where(UserOrganization.user_id == user_id)
)
@@ -456,25 +481,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
rows = result.all()
return [
{
'organization': org,
'role': role,
'member_count': member_count
}
{"organization": org, "role": role, "member_count": member_count}
for org, role, member_count in rows
]
except Exception as e:
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
logger.error(
f"Error getting user organizations with details: {e!s}", exc_info=True
)
raise
async def get_user_role_in_org(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> Optional[OrganizationRole]:
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> OrganizationRole | None:
"""Get a user's role in a specific organization."""
try:
result = await db.execute(
@@ -482,7 +501,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
UserOrganization.is_active,
)
)
)
@@ -490,29 +509,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org.role if user_org else None
except Exception as e:
logger.error(f"Error getting user role in org: {str(e)}")
logger.error(f"Error getting user role in org: {e!s}")
raise
async def is_user_org_owner(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> bool:
"""Check if a user is an owner of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role == OrganizationRole.OWNER
async def is_user_org_admin(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> bool:
"""Check if a user is an owner or admin of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]

View File

@@ -1,13 +1,13 @@
"""
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
import logging
import uuid
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from datetime import UTC, datetime, timedelta
from uuid import UUID
from sqlalchemy import and_, select, update, delete, func
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""
Get session by refresh token JTI.
@@ -38,10 +38,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
logger.error(f"Error getting session by JTI {jti}: {e!s}")
raise
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
async def get_active_by_jti(
self, db: AsyncSession, *, jti: str
) -> UserSession | None:
"""
Get active session by refresh token JTI.
@@ -57,13 +59,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
UserSession.is_active,
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
logger.error(f"Error getting active session by JTI {jti}: {e!s}")
raise
async def get_user_sessions(
@@ -72,8 +74,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
*,
user_id: str,
active_only: bool = True,
with_user: bool = False
) -> List[UserSession]:
with_user: bool = False,
) -> list[UserSession]:
"""
Get all sessions for a user with optional eager loading.
@@ -97,20 +99,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
query = query.where(UserSession.is_active)
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
logger.error(f"Error getting sessions for user {user_id}: {e!s}")
raise
async def create_session(
self,
db: AsyncSession,
*,
obj_in: SessionCreate
self, db: AsyncSession, *, obj_in: SessionCreate
) -> UserSession:
"""
Create a new user session.
@@ -151,10 +150,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
logger.error(f"Error creating session: {e!s}", exc_info=True)
raise ValueError(f"Failed to create session: {e!s}")
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
async def deactivate(
self, db: AsyncSession, *, session_id: str
) -> UserSession | None:
"""
Deactivate a session (logout from device).
@@ -184,14 +185,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return session
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
logger.error(f"Error deactivating session {session_id}: {e!s}")
raise
async def deactivate_all_user_sessions(
self,
db: AsyncSession,
*,
user_id: str
self, db: AsyncSession, *, user_id: str
) -> int:
"""
Deactivate all active sessions for a user (logout from all devices).
@@ -209,12 +207,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
stmt = (
update(UserSession)
.where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
.where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
.values(is_active=False)
)
@@ -228,14 +221,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return count
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
raise
async def update_last_used(
self,
db: AsyncSession,
*,
session: UserSession
self, db: AsyncSession, *, session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
@@ -248,14 +238,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Updated UserSession
"""
try:
session.last_used_at = datetime.now(timezone.utc)
session.last_used_at = datetime.now(UTC)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
logger.error(f"Error updating last_used for session {session.id}: {e!s}")
raise
async def update_refresh_token(
@@ -264,7 +254,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime
new_expires_at: datetime,
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
@@ -283,14 +273,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
session.last_used_at = datetime.now(UTC)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
logger.error(
f"Error updating refresh token for session {session.id}: {e!s}"
)
raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
@@ -311,15 +303,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Number of sessions deleted
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
now = datetime.now(timezone.utc)
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False,
not UserSession.is_active,
UserSession.expires_at < now,
UserSession.created_at < cutoff_date
UserSession.created_at < cutoff_date,
)
)
@@ -334,15 +326,10 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
logger.error(f"Error cleaning up expired sessions: {e!s}")
raise
async def cleanup_expired_for_user(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
"""
Clean up expired and inactive sessions for a specific user.
@@ -363,14 +350,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}")
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
UserSession.is_active == False,
UserSession.expires_at < now
not UserSession.is_active,
UserSession.expires_at < now,
)
)
@@ -388,7 +375,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
except Exception as e:
await db.rollback()
logger.error(
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
f"Error cleaning up expired sessions for user {user_id}: {e!s}"
)
raise
@@ -409,15 +396,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
result = await db.execute(
select(func.count(UserSession.id)).where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
and_(UserSession.user_id == user_uuid, UserSession.is_active)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
logger.error(f"Error counting sessions for user {user_id}: {e!s}")
raise
async def get_all_sessions(
@@ -427,8 +411,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
skip: int = 0,
limit: int = 100,
active_only: bool = True,
with_user: bool = True
) -> tuple[List[UserSession], int]:
with_user: bool = True,
) -> tuple[list[UserSession], int]:
"""
Get all sessions across all users with pagination (admin only).
@@ -451,18 +435,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
query = query.where(UserSession.is_active)
# Get total count
count_query = select(func.count(UserSession.id))
if active_only:
count_query = count_query.where(UserSession.is_active == True)
count_query = count_query.where(UserSession.is_active)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(UserSession.last_used_at.desc()).offset(skip).limit(limit)
query = (
query.order_by(UserSession.last_used_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query)
sessions = list(result.scalars().all())
@@ -470,7 +458,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return sessions, total
except Exception as e:
logger.error(f"Error getting all sessions: {str(e)}", exc_info=True)
logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
raise

View File

@@ -1,8 +1,9 @@
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import datetime, timezone
from typing import Optional, Union, Dict, Any, List, Tuple
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import or_, select, update
@@ -20,15 +21,13 @@ logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
"""Get user by email address."""
try:
result = await db.execute(
select(User).where(User.email == email)
)
result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by email {email}: {str(e)}")
logger.error(f"Error getting user by email {email}: {e!s}")
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
@@ -42,9 +41,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
password_hash=password_hash,
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
phone_number=obj_in.phone_number
if hasattr(obj_in, "phone_number")
else None,
is_superuser=obj_in.is_superuser
if hasattr(obj_in, "is_superuser")
else False,
preferences={},
)
db.add(db_obj)
await db.commit()
@@ -52,7 +55,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise ValueError(f"User with email {obj_in.email} already exists")
@@ -60,15 +63,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
) -> User:
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict):
@@ -79,7 +78,9 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
update_data["password_hash"] = await get_password_hash_async(
update_data["password"]
)
del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data)
@@ -90,11 +91,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_by: str | None = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None,
search: Optional[str] = None
) -> Tuple[List[User], int]:
filters: dict[str, Any] | None = None,
search: str | None = None,
) -> tuple[list[User], int]:
"""
Get multiple users with total count, filtering, sorting, and search.
@@ -136,12 +137,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
search_filter = or_(
User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%")
User.last_name.ilike(f"%{search}%"),
)
query = query.where(search_filter)
# Get total count
from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
@@ -162,15 +164,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return users, total
except Exception as e:
logger.error(f"Error retrieving paginated users: {str(e)}")
logger.error(f"Error retrieving paginated users: {e!s}")
raise
async def bulk_update_status(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
@@ -192,7 +190,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
.values(is_active=is_active, updated_at=datetime.now(UTC))
)
result = await db.execute(stmt)
@@ -204,15 +202,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
exclude_user_id: Optional[UUID] = None
user_ids: list[UUID],
exclude_user_id: UUID | None = None,
) -> int:
"""
Bulk soft delete multiple users.
@@ -239,11 +237,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
.where(
User.deleted_at.is_(None)
) # Don't re-delete already deleted users
.values(
deleted_at=datetime.now(timezone.utc),
deleted_at=datetime.now(UTC),
is_active=False,
updated_at=datetime.now(timezone.utc)
updated_at=datetime.now(UTC),
)
)
@@ -256,7 +256,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
raise
def is_active(self, user: User) -> bool:

View File

@@ -4,9 +4,9 @@ Async database initialization script.
Creates the first superuser if configured and doesn't already exist.
"""
import asyncio
import logging
from typing import Optional
from app.core.config import settings
from app.core.database import SessionLocal, engine
@@ -17,7 +17,7 @@ from app.schemas.users import UserCreate
logger = logging.getLogger(__name__)
async def init_db() -> Optional[User]:
async def init_db() -> User | None:
"""
Initialize database with first superuser if settings are configured and user doesn't exist.
@@ -49,7 +49,7 @@ async def init_db() -> Optional[User]:
password=superuser_password,
first_name="Admin",
last_name="User",
is_superuser=True
is_superuser=True,
)
user = await user_crud.create(session, obj_in=user_in)
@@ -70,13 +70,13 @@ async def main():
# Configure logging to show info logs
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
try:
user = await init_db()
if user:
print(f"✓ Database initialized successfully")
print("✓ Database initialized successfully")
print(f"✓ Superuser: {user.email}")
else:
print("✗ Failed to initialize database")

View File

@@ -2,10 +2,10 @@ import logging
import os
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Dict, Any
from typing import Any
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import FastAPI, status, Request, HTTPException
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
@@ -19,9 +19,9 @@ from app.core.database import check_database_health
from app.core.exceptions import (
APIException,
api_exception_handler,
validation_exception_handler,
http_exception_handler,
unhandled_exception_handler
unhandled_exception_handler,
validation_exception_handler,
)
scheduler = AsyncIOScheduler()
@@ -52,11 +52,11 @@ async def lifespan(app: FastAPI):
# Runs daily at 2:00 AM server time
scheduler.add_job(
cleanup_expired_sessions,
'cron',
"cron",
hour=2,
minute=0,
id='cleanup_expired_sessions',
replace_existing=True
id="cleanup_expired_sessions",
replace_existing=True,
)
scheduler.start()
@@ -73,12 +73,12 @@ async def lifespan(app: FastAPI):
logger.info("Scheduled jobs stopped")
logger.info(f"Starting app!!!")
logger.info("Starting app!!!")
app = FastAPI(
title=settings.PROJECT_NAME,
version=settings.VERSION,
openapi_url=f"{settings.API_V1_STR}/openapi.json",
lifespan=lifespan
lifespan=lifespan,
)
# Add rate limiter state to app
@@ -96,7 +96,14 @@ app.add_middleware(
CORSMiddleware,
allow_origins=settings.BACKEND_CORS_ORIGINS,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only
allow_methods=[
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"OPTIONS",
], # Explicit methods only
allow_headers=[
"Content-Type",
"Authorization",
@@ -129,12 +136,14 @@ async def limit_request_size(request: Request, call_next):
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content={
"success": False,
"errors": [{
"code": "REQUEST_TOO_LARGE",
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
"field": None
}]
}
"errors": [
{
"code": "REQUEST_TOO_LARGE",
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
"field": None,
}
],
},
)
response = await call_next(request)
@@ -165,15 +174,19 @@ async def add_security_headers(request: Request, call_next):
# Enforce HTTPS in production
if settings.ENVIRONMENT == "production":
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains"
)
# Content Security Policy
csp_mode = settings.CSP_MODE.lower()
# Special handling for API docs
is_docs = request.url.path in ["/docs", "/redoc"] or \
request.url.path.startswith("/docs/") or \
request.url.path.startswith("/redoc/")
is_docs = (
request.url.path in ["/docs", "/redoc"]
or request.url.path.startswith("/docs/")
or request.url.path.startswith("/redoc/")
)
if csp_mode == "disabled":
# No CSP (only for local development/debugging)
@@ -264,7 +277,7 @@ async def root():
description="Check the health status of the API and its dependencies",
response_description="Health status information",
tags=["Health"],
operation_id="health_check"
operation_id="health_check",
)
async def health_check() -> JSONResponse:
"""
@@ -278,12 +291,12 @@ async def health_check() -> JSONResponse:
- environment: Current environment (development, staging, production)
- database: Database connectivity status
"""
health_status: Dict[str, Any] = {
health_status: dict[str, Any] = {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat() + "Z",
"version": settings.VERSION,
"environment": settings.ENVIRONMENT,
"checks": {}
"checks": {},
}
response_status = status.HTTP_200_OK
@@ -294,7 +307,7 @@ async def health_check() -> JSONResponse:
if db_healthy:
health_status["checks"]["database"] = {
"status": "healthy",
"message": "Database connection successful"
"message": "Database connection successful",
}
else:
raise Exception("Database health check returned unhealthy status")
@@ -302,15 +315,12 @@ async def health_check() -> JSONResponse:
health_status["status"] = "unhealthy"
health_status["checks"]["database"] = {
"status": "unhealthy",
"message": f"Database connection failed: {str(e)}"
"message": f"Database connection failed: {e!s}",
}
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
logger.error(f"Health check failed - database error: {e}")
return JSONResponse(
status_code=response_status,
content=health_status
)
return JSONResponse(status_code=response_status, content=health_status)
app.include_router(api_router, prefix=settings.API_V1_STR)

View File

@@ -2,17 +2,25 @@
Models package initialization.
Imports all models to ensure they're registered with SQLAlchemy.
"""
# First import Base to avoid circular imports
from app.core.database import Base
from .base import TimestampMixin, UUIDMixin
from .organization import Organization
# Import models
from .user import User
from .user_organization import UserOrganization, OrganizationRole
from .user_organization import OrganizationRole, UserOrganization
from .user_session import UserSession
__all__ = [
'Base', 'TimestampMixin', 'UUIDMixin',
'User', 'UserSession',
'Organization', 'UserOrganization', 'OrganizationRole',
]
"Base",
"Organization",
"OrganizationRole",
"TimestampMixin",
"UUIDMixin",
"User",
"UserOrganization",
"UserSession",
]

View File

@@ -1,20 +1,27 @@
import uuid
from datetime import datetime, timezone
from datetime import UTC, datetime
from sqlalchemy import Column, DateTime
from sqlalchemy.dialects.postgresql import UUID
# noinspection PyUnresolvedReferences
from app.core.database import Base
class TimestampMixin:
"""Mixin to add created_at and updated_at timestamps to models"""
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
created_at = Column(
DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
nullable=False,
)
class UUIDMixin:
"""Mixin to add UUID primary keys to models"""
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

View File

@@ -1,5 +1,5 @@
# app/models/organization.py
from sqlalchemy import Column, String, Boolean, Text, Index
from sqlalchemy import Boolean, Column, Index, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
@@ -11,7 +11,8 @@ class Organization(Base, UUIDMixin, TimestampMixin):
Organization model for multi-tenant support.
Users can belong to multiple organizations with different roles.
"""
__tablename__ = 'organizations'
__tablename__ = "organizations"
name = Column(String(255), nullable=False, index=True)
slug = Column(String(255), unique=True, nullable=False, index=True)
@@ -20,11 +21,13 @@ class Organization(Base, UUIDMixin, TimestampMixin):
settings = Column(JSONB, default={})
# Relationships
user_organizations = relationship("UserOrganization", back_populates="organization", cascade="all, delete-orphan")
user_organizations = relationship(
"UserOrganization", back_populates="organization", cascade="all, delete-orphan"
)
__table_args__ = (
Index('ix_organizations_name_active', 'name', 'is_active'),
Index('ix_organizations_slug_active', 'slug', 'is_active'),
Index("ix_organizations_name_active", "name", "is_active"),
Index("ix_organizations_slug_active", "slug", "is_active"),
)
def __repr__(self):

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, String, Boolean, DateTime
from sqlalchemy import Boolean, Column, DateTime, String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
@@ -6,7 +6,7 @@ from .base import Base, TimestampMixin, UUIDMixin
class User(Base, UUIDMixin, TimestampMixin):
__tablename__ = 'users'
__tablename__ = "users"
email = Column(String(255), unique=True, nullable=False, index=True)
password_hash = Column(String(255), nullable=False)
@@ -19,7 +19,9 @@ class User(Base, UUIDMixin, TimestampMixin):
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
# Relationships
user_organizations = relationship("UserOrganization", back_populates="user", cascade="all, delete-orphan")
user_organizations = relationship(
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<User {self.email}>"
return f"<User {self.email}>"

View File

@@ -1,7 +1,7 @@
# app/models/user_organization.py
from enum import Enum as PyEnum
from sqlalchemy import Column, ForeignKey, Boolean, String, Index, Enum
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID as PGUUID
from sqlalchemy.orm import relationship
@@ -14,6 +14,7 @@ class OrganizationRole(str, PyEnum):
These provide a baseline role system that can be optionally used.
Projects can extend this or implement their own permission system.
"""
OWNER = "owner" # Full control over organization
ADMIN = "admin" # Can manage users and settings
MEMBER = "member" # Regular member with standard access
@@ -25,25 +26,41 @@ class UserOrganization(Base, TimestampMixin):
Junction table for many-to-many relationship between Users and Organizations.
Includes role information for flexible RBAC.
"""
__tablename__ = 'user_organizations'
user_id = Column(PGUUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), primary_key=True)
organization_id = Column(PGUUID(as_uuid=True), ForeignKey('organizations.id', ondelete='CASCADE'), primary_key=True)
__tablename__ = "user_organizations"
role = Column(Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False, index=True)
user_id = Column(
PGUUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
primary_key=True,
)
organization_id = Column(
PGUUID(as_uuid=True),
ForeignKey("organizations.id", ondelete="CASCADE"),
primary_key=True,
)
role = Column(
Enum(OrganizationRole),
default=OrganizationRole.MEMBER,
nullable=False,
index=True,
)
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Optional: Custom permissions override for specific users
custom_permissions = Column(String(500), nullable=True) # JSON array of permission strings
custom_permissions = Column(
String(500), nullable=True
) # JSON array of permission strings
# Relationships
user = relationship("User", back_populates="user_organizations")
organization = relationship("Organization", back_populates="user_organizations")
__table_args__ = (
Index('ix_user_org_user_active', 'user_id', 'is_active'),
Index('ix_user_org_org_active', 'organization_id', 'is_active'),
Index('ix_user_org_role', 'role'),
Index("ix_user_org_user_active", "user_id", "is_active"),
Index("ix_user_org_org_active", "organization_id", "is_active"),
Index("ix_user_org_role", "role"),
)
def __repr__(self):

View File

@@ -6,7 +6,10 @@ This allows users to:
- Logout from specific devices
- Manage their active sessions
"""
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Index
from datetime import UTC
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -20,19 +23,27 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
Each time a user logs in from a device, a new session is created.
Sessions are identified by the refresh token JTI (JWT ID).
"""
__tablename__ = 'user_sessions'
__tablename__ = "user_sessions"
# Foreign key to user
user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False, index=True)
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Refresh token identifier (JWT ID from the refresh token)
refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True)
# Device information
device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook"
device_id = Column(String(255), nullable=True) # Persistent device identifier (from client)
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
user_agent = Column(String(500), nullable=True) # Browser/app user agent
device_id = Column(
String(255), nullable=True
) # Persistent device identifier (from client)
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
user_agent = Column(String(500), nullable=True) # Browser/app user agent
# Session timing
last_used_at = Column(DateTime(timezone=True), nullable=False)
@@ -50,8 +61,8 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
# Composite indexes for performance (defined in migration)
__table_args__ = (
Index('ix_user_sessions_user_active', 'user_id', 'is_active'),
Index('ix_user_sessions_jti_active', 'refresh_token_jti', 'is_active'),
Index("ix_user_sessions_user_active", "user_id", "is_active"),
Index("ix_user_sessions_jti_active", "refresh_token_jti", "is_active"),
)
def __repr__(self):
@@ -60,21 +71,24 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
@property
def is_expired(self) -> bool:
"""Check if session has expired."""
from datetime import datetime, timezone
return self.expires_at < datetime.now(timezone.utc)
from datetime import datetime
return self.expires_at < datetime.now(UTC)
def to_dict(self):
"""Convert session to dictionary for serialization."""
return {
'id': str(self.id),
'user_id': str(self.user_id),
'device_name': self.device_name,
'device_id': self.device_id,
'ip_address': self.ip_address,
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
'expires_at': self.expires_at.isoformat() if self.expires_at else None,
'is_active': self.is_active,
'location_city': self.location_city,
'location_country': self.location_country,
'created_at': self.created_at.isoformat() if self.created_at else None,
"id": str(self.id),
"user_id": str(self.user_id),
"device_name": self.device_name,
"device_id": self.device_id,
"ip_address": self.ip_address,
"last_used_at": self.last_used_at.isoformat()
if self.last_used_at
else None,
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"is_active": self.is_active,
"location_city": self.location_city,
"location_country": self.location_country,
"created_at": self.created_at.isoformat() if self.created_at else None,
}

View File

@@ -1,18 +1,20 @@
"""
Common schemas used across the API for pagination, responses, filtering, and sorting.
"""
from enum import Enum
from math import ceil
from typing import Generic, TypeVar, List, Optional
from typing import TypeVar
from uuid import UUID
from pydantic import BaseModel, Field
T = TypeVar('T')
T = TypeVar("T")
class SortOrder(str, Enum):
"""Sort order options."""
ASC = "asc"
DESC = "desc"
@@ -20,16 +22,9 @@ class SortOrder(str, Enum):
class PaginationParams(BaseModel):
"""Parameters for pagination."""
page: int = Field(
default=1,
ge=1,
description="Page number (1-indexed)"
)
page: int = Field(default=1, ge=1, description="Page number (1-indexed)")
limit: int = Field(
default=20,
ge=1,
le=100,
description="Number of items per page (max 100)"
default=20, ge=1, le=100, description="Number of items per page (max 100)"
)
@property
@@ -42,34 +37,20 @@ class PaginationParams(BaseModel):
"""Alias for offset (compatibility with existing code)."""
return self.offset
model_config = {
"json_schema_extra": {
"example": {
"page": 1,
"limit": 20
}
}
}
model_config = {"json_schema_extra": {"example": {"page": 1, "limit": 20}}}
class SortParams(BaseModel):
"""Parameters for sorting."""
sort_by: Optional[str] = Field(
default=None,
description="Field name to sort by"
)
sort_by: str | None = Field(default=None, description="Field name to sort by")
sort_order: SortOrder = Field(
default=SortOrder.ASC,
description="Sort order (asc or desc)"
default=SortOrder.ASC, description="Sort order (asc or desc)"
)
model_config = {
"json_schema_extra": {
"example": {
"sort_by": "created_at",
"sort_order": "desc"
}
"example": {"sort_by": "created_at", "sort_order": "desc"}
}
}
@@ -92,32 +73,30 @@ class PaginationMeta(BaseModel):
"page_size": 20,
"total_pages": 8,
"has_next": True,
"has_prev": False
"has_prev": False,
}
}
}
class PaginatedResponse(BaseModel, Generic[T]):
class PaginatedResponse[T](BaseModel):
"""Generic paginated response wrapper."""
data: List[T] = Field(..., description="List of items")
data: list[T] = Field(..., description="List of items")
pagination: PaginationMeta = Field(..., description="Pagination metadata")
model_config = {
"json_schema_extra": {
"example": {
"data": [
{"id": "123", "name": "Example Item"}
],
"data": [{"id": "123", "name": "Example Item"}],
"pagination": {
"total": 150,
"page": 1,
"page_size": 20,
"total_pages": 8,
"has_next": True,
"has_prev": False
}
"has_prev": False,
},
}
}
}
@@ -131,10 +110,7 @@ class MessageResponse(BaseModel):
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"message": "Operation completed successfully"
}
"example": {"success": True, "message": "Operation completed successfully"}
}
}
@@ -142,11 +118,11 @@ class MessageResponse(BaseModel):
class BulkActionRequest(BaseModel):
"""Request schema for bulk operations on multiple items."""
ids: List[UUID] = Field(
ids: list[UUID] = Field(
...,
min_length=1,
max_length=100,
description="List of item IDs to perform action on (max 100)"
description="List of item IDs to perform action on (max 100)",
)
model_config = {
@@ -154,7 +130,7 @@ class BulkActionRequest(BaseModel):
"example": {
"ids": [
"550e8400-e29b-41d4-a716-446655440000",
"6ba7b810-9dad-11d1-80b4-00c04fd430c8"
"6ba7b810-9dad-11d1-80b4-00c04fd430c8",
]
}
}
@@ -166,24 +142,23 @@ class BulkActionResponse(BaseModel):
success: bool = Field(default=True, description="Operation success status")
message: str = Field(..., description="Human-readable message")
affected_count: int = Field(..., description="Number of items affected by the operation")
affected_count: int = Field(
..., description="Number of items affected by the operation"
)
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"message": "Successfully deactivated 5 users",
"affected_count": 5
"affected_count": 5,
}
}
}
def create_pagination_meta(
total: int,
page: int,
limit: int,
items_count: int
total: int, page: int, limit: int, items_count: int
) -> PaginationMeta:
"""
Helper function to create pagination metadata.
@@ -205,5 +180,5 @@ def create_pagination_meta(
page_size=items_count,
total_pages=total_pages,
has_next=page < total_pages,
has_prev=page > 1
has_prev=page > 1,
)

View File

@@ -1,8 +1,8 @@
"""
Error schemas for standardized API error responses.
"""
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -53,14 +53,14 @@ class ErrorDetail(BaseModel):
code: ErrorCode = Field(..., description="Machine-readable error code")
message: str = Field(..., description="Human-readable error message")
field: Optional[str] = Field(None, description="Field name if error is field-specific")
field: str | None = Field(None, description="Field name if error is field-specific")
model_config = {
"json_schema_extra": {
"example": {
"code": "VAL_002",
"message": "Password must be at least 8 characters long",
"field": "password"
"field": "password",
}
}
}
@@ -70,7 +70,7 @@ class ErrorResponse(BaseModel):
"""Standardized error response format."""
success: bool = Field(default=False, description="Always false for error responses")
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred")
errors: list[ErrorDetail] = Field(..., description="List of errors that occurred")
model_config = {
"json_schema_extra": {
@@ -80,9 +80,9 @@ class ErrorResponse(BaseModel):
{
"code": "AUTH_001",
"message": "Invalid email or password",
"field": None
"field": None,
}
]
],
}
}
}

View File

@@ -1,10 +1,10 @@
# app/schemas/organizations.py
import re
from datetime import datetime
from typing import Optional, Dict, Any, List
from typing import Any
from uuid import UUID
from pydantic import BaseModel, field_validator, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
from app.models.user_organization import OrganizationRole
@@ -12,85 +12,94 @@ from app.models.user_organization import OrganizationRole
# Organization Schemas
class OrganizationBase(BaseModel):
"""Base organization schema with common fields."""
name: str = Field(..., min_length=1, max_length=255)
slug: Optional[str] = Field(None, min_length=1, max_length=255)
description: Optional[str] = None
is_active: bool = True
settings: Optional[Dict[str, Any]] = {}
@field_validator('slug')
name: str = Field(..., min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
is_active: bool = True
settings: dict[str, Any] | None = {}
@field_validator("slug")
@classmethod
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
if v is None:
return v
if not re.match(r'^[a-z0-9-]+$', v):
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
if v.startswith('-') or v.endswith('-'):
raise ValueError('Slug cannot start or end with a hyphen')
if '--' in v:
raise ValueError('Slug cannot contain consecutive hyphens')
if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError(
"Slug must contain only lowercase letters, numbers, and hyphens"
)
if v.startswith("-") or v.endswith("-"):
raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v
@field_validator('name')
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""Validate organization name."""
if not v or v.strip() == "":
raise ValueError('Organization name cannot be empty')
raise ValueError("Organization name cannot be empty")
return v.strip()
class OrganizationCreate(OrganizationBase):
"""Schema for creating a new organization."""
name: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255)
class OrganizationUpdate(BaseModel):
"""Schema for updating an organization."""
name: Optional[str] = Field(None, min_length=1, max_length=255)
slug: Optional[str] = Field(None, min_length=1, max_length=255)
description: Optional[str] = None
is_active: Optional[bool] = None
settings: Optional[Dict[str, Any]] = None
@field_validator('slug')
name: str | None = Field(None, min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
is_active: bool | None = None
settings: dict[str, Any] | None = None
@field_validator("slug")
@classmethod
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format."""
if v is None:
return v
if not re.match(r'^[a-z0-9-]+$', v):
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
if v.startswith('-') or v.endswith('-'):
raise ValueError('Slug cannot start or end with a hyphen')
if '--' in v:
raise ValueError('Slug cannot contain consecutive hyphens')
if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError(
"Slug must contain only lowercase letters, numbers, and hyphens"
)
if v.startswith("-") or v.endswith("-"):
raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v
@field_validator('name')
@field_validator("name")
@classmethod
def validate_name(cls, v: Optional[str]) -> Optional[str]:
def validate_name(cls, v: str | None) -> str | None:
"""Validate organization name."""
if v is not None and (not v or v.strip() == ""):
raise ValueError('Organization name cannot be empty')
raise ValueError("Organization name cannot be empty")
return v.strip() if v else v
class OrganizationResponse(OrganizationBase):
"""Schema for organization API responses."""
id: UUID
created_at: datetime
updated_at: Optional[datetime] = None
member_count: Optional[int] = 0
updated_at: datetime | None = None
member_count: int | None = 0
model_config = ConfigDict(from_attributes=True)
class OrganizationListResponse(BaseModel):
"""Schema for paginated organization list responses."""
organizations: List[OrganizationResponse]
organizations: list[OrganizationResponse]
total: int
page: int
page_size: int
@@ -100,44 +109,49 @@ class OrganizationListResponse(BaseModel):
# User-Organization Relationship Schemas
class UserOrganizationBase(BaseModel):
"""Base schema for user-organization relationship."""
role: OrganizationRole = OrganizationRole.MEMBER
is_active: bool = True
custom_permissions: Optional[str] = None
custom_permissions: str | None = None
class UserOrganizationCreate(BaseModel):
"""Schema for adding a user to an organization."""
user_id: UUID
role: OrganizationRole = OrganizationRole.MEMBER
custom_permissions: Optional[str] = None
custom_permissions: str | None = None
class UserOrganizationUpdate(BaseModel):
"""Schema for updating user's role in an organization."""
role: Optional[OrganizationRole] = None
is_active: Optional[bool] = None
custom_permissions: Optional[str] = None
role: OrganizationRole | None = None
is_active: bool | None = None
custom_permissions: str | None = None
class UserOrganizationResponse(BaseModel):
"""Schema for user-organization relationship responses."""
user_id: UUID
organization_id: UUID
role: OrganizationRole
is_active: bool
custom_permissions: Optional[str] = None
custom_permissions: str | None = None
created_at: datetime
updated_at: Optional[datetime] = None
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
class OrganizationMemberResponse(BaseModel):
"""Schema for organization member information."""
user_id: UUID
email: str
first_name: str
last_name: Optional[str] = None
last_name: str | None = None
role: OrganizationRole
is_active: bool
joined_at: datetime
@@ -147,7 +161,8 @@ class OrganizationMemberResponse(BaseModel):
class OrganizationMemberListResponse(BaseModel):
"""Schema for paginated organization member list."""
members: List[OrganizationMemberResponse]
members: list[OrganizationMemberResponse]
total: int
page: int
page_size: int

View File

@@ -1,37 +1,44 @@
"""
Pydantic schemas for user session management.
"""
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class SessionBase(BaseModel):
"""Base schema for user sessions."""
device_name: Optional[str] = Field(None, max_length=255, description="Friendly device name")
device_id: Optional[str] = Field(None, max_length=255, description="Persistent device identifier")
device_name: str | None = Field(
None, max_length=255, description="Friendly device name"
)
device_id: str | None = Field(
None, max_length=255, description="Persistent device identifier"
)
class SessionCreate(SessionBase):
"""Schema for creating a new session (internal use)."""
user_id: UUID
refresh_token_jti: str = Field(..., max_length=255)
ip_address: Optional[str] = Field(None, max_length=45)
user_agent: Optional[str] = Field(None, max_length=500)
ip_address: str | None = Field(None, max_length=45)
user_agent: str | None = Field(None, max_length=500)
last_used_at: datetime
expires_at: datetime
location_city: Optional[str] = Field(None, max_length=100)
location_country: Optional[str] = Field(None, max_length=100)
location_city: str | None = Field(None, max_length=100)
location_country: str | None = Field(None, max_length=100)
class SessionUpdate(BaseModel):
"""Schema for updating a session (internal use)."""
last_used_at: Optional[datetime] = None
is_active: Optional[bool] = None
refresh_token_jti: Optional[str] = None
expires_at: Optional[datetime] = None
last_used_at: datetime | None = None
is_active: bool | None = None
refresh_token_jti: str | None = None
expires_at: datetime | None = None
class SessionResponse(SessionBase):
@@ -40,14 +47,17 @@ class SessionResponse(SessionBase):
This is what users see when they list their active sessions.
"""
id: UUID
ip_address: Optional[str] = None
location_city: Optional[str] = None
location_country: Optional[str] = None
ip_address: str | None = None
location_city: str | None = None
location_country: str | None = None
last_used_at: datetime
created_at: datetime
expires_at: datetime
is_current: bool = Field(default=False, description="Whether this is the current session")
is_current: bool = Field(
default=False, description="Whether this is the current session"
)
model_config = ConfigDict(
from_attributes=True,
@@ -62,14 +72,15 @@ class SessionResponse(SessionBase):
"last_used_at": "2025-10-31T12:00:00Z",
"created_at": "2025-10-30T09:00:00Z",
"expires_at": "2025-11-06T09:00:00Z",
"is_current": True
"is_current": True,
}
}
},
)
class SessionListResponse(BaseModel):
"""Response containing list of sessions."""
sessions: list[SessionResponse]
total: int = Field(..., description="Total number of active sessions")
@@ -84,10 +95,10 @@ class SessionListResponse(BaseModel):
"last_used_at": "2025-10-31T12:00:00Z",
"created_at": "2025-10-30T09:00:00Z",
"expires_at": "2025-11-06T09:00:00Z",
"is_current": True
"is_current": True,
}
],
"total": 1
"total": 1,
}
}
)
@@ -95,17 +106,14 @@ class SessionListResponse(BaseModel):
class LogoutRequest(BaseModel):
"""Request schema for logout endpoint."""
refresh_token: str = Field(
...,
description="Refresh token for the session to logout from",
min_length=10
..., description="Refresh token for the session to logout from", min_length=10
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
"example": {"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}
}
)
@@ -116,13 +124,14 @@ class AdminSessionResponse(SessionBase):
Includes user information for admin to see who owns each session.
"""
id: UUID
user_id: UUID
user_email: str = Field(..., description="Email of the user who owns this session")
user_full_name: Optional[str] = Field(None, description="Full name of the user")
ip_address: Optional[str] = None
location_city: Optional[str] = None
location_country: Optional[str] = None
user_full_name: str | None = Field(None, description="Full name of the user")
ip_address: str | None = None
location_city: str | None = None
location_country: str | None = None
last_used_at: datetime
created_at: datetime
expires_at: datetime
@@ -144,20 +153,21 @@ class AdminSessionResponse(SessionBase):
"last_used_at": "2025-10-31T12:00:00Z",
"created_at": "2025-10-30T09:00:00Z",
"expires_at": "2025-11-06T09:00:00Z",
"is_active": True
"is_active": True,
}
}
},
)
class DeviceInfo(BaseModel):
"""Device information extracted from request."""
device_name: Optional[str] = None
device_id: Optional[str] = None
ip_address: Optional[str] = None
user_agent: Optional[str] = None
location_city: Optional[str] = None
location_country: Optional[str] = None
device_name: str | None = None
device_id: str | None = None
ip_address: str | None = None
user_agent: str | None = None
location_city: str | None = None
location_country: str | None = None
model_config = ConfigDict(
json_schema_extra={
@@ -167,7 +177,7 @@ class DeviceInfo(BaseModel):
"ip_address": "192.168.1.50",
"user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...",
"location_city": "San Francisco",
"location_country": "United States"
"location_country": "United States",
}
}
)

View File

@@ -1,9 +1,9 @@
# app/schemas/users.py
from datetime import datetime
from typing import Optional, Dict, Any
from typing import Any
from uuid import UUID
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from app.schemas.validators import validate_password_strength, validate_phone_number
@@ -11,12 +11,12 @@ from app.schemas.validators import validate_password_strength, validate_phone_nu
class UserBase(BaseModel):
email: EmailStr
first_name: str
last_name: Optional[str] = None
phone_number: Optional[str] = None
last_name: str | None = None
phone_number: str | None = None
@field_validator('phone_number')
@field_validator("phone_number")
@classmethod
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
def validate_phone(cls, v: str | None) -> str | None:
return validate_phone_number(v)
@@ -24,7 +24,7 @@ class UserCreate(UserBase):
password: str
is_superuser: bool = False
@field_validator('password')
@field_validator("password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation"""
@@ -32,30 +32,32 @@ class UserCreate(UserBase):
class UserUpdate(BaseModel):
first_name: Optional[str] = None
last_name: Optional[str] = None
phone_number: Optional[str] = None
password: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates
is_superuser: Optional[bool] = None # Explicitly reject privilege escalation attempts
first_name: str | None = None
last_name: str | None = None
phone_number: str | None = None
password: str | None = None
preferences: dict[str, Any] | None = None
is_active: bool | None = (
None # Changed default from True to None to avoid unintended updates
)
is_superuser: bool | None = None # Explicitly reject privilege escalation attempts
@field_validator('phone_number')
@field_validator("phone_number")
@classmethod
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
def validate_phone(cls, v: str | None) -> str | None:
return validate_phone_number(v)
@field_validator('password')
@field_validator("password")
@classmethod
def password_strength(cls, v: Optional[str]) -> Optional[str]:
def password_strength(cls, v: str | None) -> str | None:
"""Enterprise-grade password strength validation"""
if v is None:
return v
return validate_password_strength(v)
@field_validator('is_superuser')
@field_validator("is_superuser")
@classmethod
def prevent_superuser_modification(cls, v: Optional[bool]) -> Optional[bool]:
def prevent_superuser_modification(cls, v: bool | None) -> bool | None:
"""Prevent users from modifying their superuser status via this schema."""
if v is not None:
raise ValueError("Cannot modify superuser status through user update")
@@ -67,7 +69,7 @@ class UserInDB(UserBase):
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@@ -77,28 +79,28 @@ class UserResponse(UserBase):
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
class Token(BaseModel):
access_token: str
refresh_token: Optional[str] = None
refresh_token: str | None = None
token_type: str = "bearer"
user: "UserResponse" # Forward reference since UserResponse is defined above
expires_in: Optional[int] = None # Token expiration in seconds
expires_in: int | None = None # Token expiration in seconds
class TokenPayload(BaseModel):
sub: str # User ID
exp: int # Expiration time
iat: Optional[int] = None # Issued at
jti: Optional[str] = None # JWT ID
is_superuser: Optional[bool] = False
first_name: Optional[str] = None
email: Optional[str] = None
type: Optional[str] = None # Token type (access/refresh)
iat: int | None = None # Issued at
jti: str | None = None # JWT ID
is_superuser: bool | None = False
first_name: str | None = None
email: str | None = None
type: str | None = None # Token type (access/refresh)
class TokenData(BaseModel):
@@ -108,10 +110,11 @@ class TokenData(BaseModel):
class PasswordChange(BaseModel):
"""Schema for changing password (requires current password)."""
current_password: str
new_password: str
@field_validator('new_password')
@field_validator("new_password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation"""
@@ -120,10 +123,11 @@ class PasswordChange(BaseModel):
class PasswordReset(BaseModel):
"""Schema for resetting password (via email token)."""
token: str
new_password: str
@field_validator('new_password')
@field_validator("new_password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation"""
@@ -141,23 +145,19 @@ class RefreshTokenRequest(BaseModel):
class PasswordResetRequest(BaseModel):
"""Schema for requesting a password reset."""
email: EmailStr = Field(..., description="Email address of the account")
model_config = {
"json_schema_extra": {
"example": {
"email": "user@example.com"
}
}
}
model_config = {"json_schema_extra": {"example": {"email": "user@example.com"}}}
class PasswordResetConfirm(BaseModel):
"""Schema for confirming a password reset with token."""
token: str = Field(..., description="Password reset token from email")
new_password: str = Field(..., min_length=8, description="New password")
@field_validator('new_password')
@field_validator("new_password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation"""
@@ -167,7 +167,7 @@ class PasswordResetConfirm(BaseModel):
"json_schema_extra": {
"example": {
"token": "eyJwYXlsb2FkIjp7ImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTcxMjM0NTY3OH19",
"new_password": "NewSecurePassword123"
"new_password": "NewSecurePassword123",
}
}
}

View File

@@ -4,19 +4,34 @@ Shared validators for Pydantic schemas.
This module provides reusable validation functions to ensure consistency
across all schemas and avoid code duplication.
"""
import re
from typing import Set
# Common weak passwords that should be rejected
COMMON_PASSWORDS: Set[str] = {
'password', 'password1', 'password123', 'password1234',
'admin', 'admin123', 'admin1234',
'welcome', 'welcome1', 'welcome123',
'qwerty', 'qwerty123',
'12345678', '123456789', '1234567890',
'letmein', 'letmein1', 'letmein123',
'monkey123', 'dragon123',
'passw0rd', 'p@ssw0rd', 'p@ssword',
COMMON_PASSWORDS: set[str] = {
"password",
"password1",
"password123",
"password1234",
"admin",
"admin123",
"admin1234",
"welcome",
"welcome1",
"welcome123",
"qwerty",
"qwerty123",
"12345678",
"123456789",
"1234567890",
"letmein",
"letmein1",
"letmein123",
"monkey123",
"dragon123",
"passw0rd",
"p@ssw0rd",
"p@ssword",
}
@@ -47,18 +62,21 @@ def validate_password_strength(password: str) -> str:
"""
# Check minimum length
if len(password) < 12:
raise ValueError('Password must be at least 12 characters long')
raise ValueError("Password must be at least 12 characters long")
# Check against common passwords (case-insensitive)
if password.lower() in COMMON_PASSWORDS:
raise ValueError('Password is too common. Please choose a stronger password')
raise ValueError("Password is too common. Please choose a stronger password")
# Check for required character types
checks = [
(any(c.islower() for c in password), 'at least one lowercase letter'),
(any(c.isupper() for c in password), 'at least one uppercase letter'),
(any(c.isdigit() for c in password), 'at least one digit'),
(any(c in '!@#$%^&*()_+-=[]{}|;:,.<>?~`' for c in password), 'at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)')
(any(c.islower() for c in password), "at least one lowercase letter"),
(any(c.isupper() for c in password), "at least one uppercase letter"),
(any(c.isdigit() for c in password), "at least one digit"),
(
any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?~`" for c in password),
"at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)",
),
]
failed = [msg for check, msg in checks if not check]
@@ -94,10 +112,10 @@ def validate_phone_number(phone: str | None) -> str | None:
# Check for empty strings
if not phone or phone.strip() == "":
raise ValueError('Phone number cannot be empty')
raise ValueError("Phone number cannot be empty")
# Remove all spaces and formatting characters
cleaned = re.sub(r'[\s\-\(\)]', '', phone)
cleaned = re.sub(r"[\s\-\(\)]", "", phone)
# Basic pattern:
# Must start with + or 0
@@ -105,19 +123,19 @@ def validate_phone_number(phone: str | None) -> str | None:
# After 0 must have at least 8 digits
# Maximum total length of 15 digits (international standard)
# Only allowed characters are + at start and digits
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
pattern = r"^(?:\+[0-9]{8,14}|0[0-9]{8,14})$"
if not re.match(pattern, cleaned):
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
raise ValueError("Phone number must start with + or 0 followed by 8-14 digits")
# Additional validation to catch specific invalid cases
# NOTE: These checks are defensive code - the regex pattern above already catches these cases
if cleaned.count('+') > 1: # pragma: no cover
raise ValueError('Phone number can only contain one + symbol at the start')
if cleaned.count("+") > 1: # pragma: no cover
raise ValueError("Phone number can only contain one + symbol at the start")
# Check for any non-digit characters (except the leading +)
if not all(c.isdigit() for c in cleaned[1:]): # pragma: no cover
raise ValueError('Phone number can only contain digits after the prefix')
raise ValueError("Phone number can only contain digits after the prefix")
return cleaned
@@ -169,16 +187,16 @@ def validate_slug(slug: str) -> str:
ValueError: If slug format is invalid
"""
if not slug or len(slug) < 2:
raise ValueError('Slug must be at least 2 characters long')
raise ValueError("Slug must be at least 2 characters long")
if len(slug) > 50:
raise ValueError('Slug must be at most 50 characters long')
raise ValueError("Slug must be at most 50 characters long")
# Check format
if not re.match(r'^[a-z0-9]+(?:-[a-z0-9]+)*$', slug):
if not re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", slug):
raise ValueError(
'Slug can only contain lowercase letters, numbers, and hyphens. '
'It cannot start or end with a hyphen, and cannot contain consecutive hyphens'
"Slug can only contain lowercase letters, numbers, and hyphens. "
"It cannot start or end with a hyphen, and cannot contain consecutive hyphens"
)
return slug

View File

@@ -1,18 +1,17 @@
# app/services/auth_service.py
import logging
from typing import Optional
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import (
verify_password_async,
get_password_hash_async,
TokenExpiredError,
TokenInvalidError,
create_access_token,
create_refresh_token,
TokenExpiredError,
TokenInvalidError
get_password_hash_async,
verify_password_async,
)
from app.core.config import settings
from app.core.exceptions import AuthenticationError
@@ -26,7 +25,9 @@ class AuthService:
"""Service for handling authentication operations"""
@staticmethod
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
async def authenticate_user(
db: AsyncSession, email: str, password: str
) -> User | None:
"""
Authenticate a user with email and password using async password verification.
@@ -87,7 +88,7 @@ class AuthService:
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False
is_superuser=False,
)
db.add(user)
@@ -103,8 +104,8 @@ class AuthService:
except Exception as e:
# Rollback on any database errors
await db.rollback()
logger.error(f"Error creating user: {str(e)}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {str(e)}")
logger.error(f"Error creating user: {e!s}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {e!s}")
@staticmethod
def create_tokens(user: User) -> Token:
@@ -121,18 +122,13 @@ class AuthService:
claims = {
"is_superuser": user.is_superuser,
"email": user.email,
"first_name": user.first_name
"first_name": user.first_name,
}
# Create tokens
access_token = create_access_token(
subject=str(user.id),
claims=claims
)
access_token = create_access_token(subject=str(user.id), claims=claims)
refresh_token = create_refresh_token(
subject=str(user.id)
)
refresh_token = create_refresh_token(subject=str(user.id))
# Convert User model to UserResponse schema
user_response = UserResponse.model_validate(user)
@@ -141,7 +137,8 @@ class AuthService:
access_token=access_token,
refresh_token=refresh_token,
user=user_response,
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # Convert minutes to seconds
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES
* 60, # Convert minutes to seconds
)
@staticmethod
@@ -180,11 +177,13 @@ class AuthService:
return AuthService.create_tokens(user)
except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {str(e)}")
logger.warning(f"Token refresh failed: {e!s}")
raise
@staticmethod
async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool:
async def change_password(
db: AsyncSession, user_id: UUID, current_password: str, new_password: str
) -> bool:
"""
Change a user's password.
@@ -223,5 +222,7 @@ class AuthService:
except Exception as e:
# Rollback on any database errors
await db.rollback()
logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True)
raise AuthenticationError(f"Failed to change password: {str(e)}")
logger.error(
f"Error changing password for user {user_id}: {e!s}", exc_info=True
)
raise AuthenticationError(f"Failed to change password: {e!s}")

View File

@@ -5,9 +5,9 @@ Email service with placeholder implementation.
This service provides email sending functionality with a simple console/log-based
placeholder that can be easily replaced with a real email provider (SendGrid, SES, etc.)
"""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
from app.core.config import settings
@@ -20,13 +20,12 @@ class EmailBackend(ABC):
@abstractmethod
async def send_email(
self,
to: List[str],
to: list[str],
subject: str,
html_content: str,
text_content: Optional[str] = None
text_content: str | None = None,
) -> bool:
"""Send an email."""
pass
class ConsoleEmailBackend(EmailBackend):
@@ -39,10 +38,10 @@ class ConsoleEmailBackend(EmailBackend):
async def send_email(
self,
to: List[str],
to: list[str],
subject: str,
html_content: str,
text_content: Optional[str] = None
text_content: str | None = None,
) -> bool:
"""
Log email content to console/logs.
@@ -88,10 +87,10 @@ class SMTPEmailBackend(EmailBackend):
async def send_email(
self,
to: List[str],
to: list[str],
subject: str,
html_content: str,
text_content: Optional[str] = None
text_content: str | None = None,
) -> bool:
"""Send email via SMTP."""
# TODO: Implement SMTP sending
@@ -108,7 +107,7 @@ class EmailService:
and can be configured to use different backends (console, SMTP, SendGrid, etc.)
"""
def __init__(self, backend: Optional[EmailBackend] = None):
def __init__(self, backend: EmailBackend | None = None):
"""
Initialize email service with a backend.
@@ -118,10 +117,7 @@ class EmailService:
self.backend = backend or ConsoleEmailBackend()
async def send_password_reset_email(
self,
to_email: str,
reset_token: str,
user_name: Optional[str] = None
self, to_email: str, reset_token: str, user_name: str | None = None
) -> bool:
"""
Send password reset email.
@@ -142,7 +138,7 @@ class EmailService:
# Plain text version
text_content = f"""
Hello{' ' + user_name if user_name else ''},
Hello{" " + user_name if user_name else ""},
You requested a password reset for your account. Click the link below to reset your password:
@@ -177,7 +173,7 @@ The {settings.PROJECT_NAME} Team
<h1>Password Reset</h1>
</div>
<div class="content">
<p>Hello{' ' + user_name if user_name else ''},</p>
<p>Hello{" " + user_name if user_name else ""},</p>
<p>You requested a password reset for your account. Click the button below to reset your password:</p>
<p style="text-align: center;">
<a href="{reset_url}" class="button">Reset Password</a>
@@ -200,17 +196,14 @@ The {settings.PROJECT_NAME} Team
to=[to_email],
subject=subject,
html_content=html_content,
text_content=text_content
text_content=text_content,
)
except Exception as e:
logger.error(f"Failed to send password reset email to {to_email}: {str(e)}")
logger.error(f"Failed to send password reset email to {to_email}: {e!s}")
return False
async def send_email_verification(
self,
to_email: str,
verification_token: str,
user_name: Optional[str] = None
self, to_email: str, verification_token: str, user_name: str | None = None
) -> bool:
"""
Send email verification email.
@@ -224,14 +217,16 @@ The {settings.PROJECT_NAME} Team
True if email sent successfully
"""
# Generate verification URL
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
verification_url = (
f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
)
# Prepare email content
subject = "Verify Your Email Address"
# Plain text version
text_content = f"""
Hello{' ' + user_name if user_name else ''},
Hello{" " + user_name if user_name else ""},
Thank you for signing up! Please verify your email address by clicking the link below:
@@ -266,7 +261,7 @@ The {settings.PROJECT_NAME} Team
<h1>Verify Your Email</h1>
</div>
<div class="content">
<p>Hello{' ' + user_name if user_name else ''},</p>
<p>Hello{" " + user_name if user_name else ""},</p>
<p>Thank you for signing up! Please verify your email address by clicking the button below:</p>
<p style="text-align: center;">
<a href="{verification_url}" class="button">Verify Email</a>
@@ -289,10 +284,10 @@ The {settings.PROJECT_NAME} Team
to=[to_email],
subject=subject,
html_content=html_content,
text_content=text_content
text_content=text_content,
)
except Exception as e:
logger.error(f"Failed to send verification email to {to_email}: {str(e)}")
logger.error(f"Failed to send verification email to {to_email}: {e!s}")
return False

View File

@@ -3,8 +3,9 @@ Background job for cleaning up expired sessions.
This service runs periodically to remove old session records from the database.
"""
import logging
from datetime import datetime, timezone
from datetime import UTC, datetime
from app.core.database import SessionLocal
from app.crud.session import session as session_crud
@@ -39,7 +40,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
return count
except Exception as e:
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
logger.error(f"Error during session cleanup: {e!s}", exc_info=True)
return 0
@@ -52,20 +53,21 @@ async def get_session_statistics() -> dict:
"""
async with SessionLocal() as db:
try:
from sqlalchemy import func, select
from app.models.user_session import UserSession
from sqlalchemy import select, func
total_result = await db.execute(select(func.count(UserSession.id)))
total_sessions = total_result.scalar_one()
active_result = await db.execute(
select(func.count(UserSession.id)).where(UserSession.is_active == True)
select(func.count(UserSession.id)).where(UserSession.is_active)
)
active_sessions = active_result.scalar_one()
expired_result = await db.execute(
select(func.count(UserSession.id)).where(
UserSession.expires_at < datetime.now(timezone.utc)
UserSession.expires_at < datetime.now(UTC)
)
)
expired_sessions = expired_result.scalar_one()
@@ -82,5 +84,5 @@ async def get_session_statistics() -> dict:
return stats
except Exception as e:
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
logger.error(f"Error getting session statistics: {e!s}", exc_info=True)
return {}

View File

@@ -2,7 +2,8 @@
Authentication utilities for testing.
This module provides tools to bypass FastAPI's authentication in tests.
"""
from typing import Callable, Dict, Optional
from collections.abc import Callable
from fastapi import FastAPI
from fastapi.security import OAuth2PasswordBearer
@@ -13,9 +14,9 @@ from app.models.user import User
def create_test_auth_client(
app: FastAPI,
test_user: User,
extra_overrides: Optional[Dict[Callable, Callable]] = None
app: FastAPI,
test_user: User,
extra_overrides: dict[Callable, Callable] | None = None,
) -> TestClient:
"""
Create a test client with authentication pre-configured.
@@ -47,10 +48,7 @@ def create_test_auth_client(
return TestClient(app)
def create_test_optional_auth_client(
app: FastAPI,
test_user: User
) -> TestClient:
def create_test_optional_auth_client(app: FastAPI, test_user: User) -> TestClient:
"""
Create a test client with optional authentication pre-configured.
@@ -70,10 +68,7 @@ def create_test_optional_auth_client(
return TestClient(app)
def create_test_superuser_client(
app: FastAPI,
test_user: User
) -> TestClient:
def create_test_superuser_client(app: FastAPI, test_user: User) -> TestClient:
"""
Create a test client with superuser authentication pre-configured.
@@ -120,7 +115,7 @@ def cleanup_test_client_auth(app: FastAPI) -> None:
auth_deps = [
get_current_user,
get_optional_current_user,
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login"),
]
# Remove overrides

View File

@@ -1,8 +1,8 @@
"""
Utility functions for extracting and parsing device information from HTTP requests.
"""
import re
from typing import Optional
from fastapi import Request
@@ -19,11 +19,11 @@ def extract_device_info(request: Request) -> DeviceInfo:
Returns:
DeviceInfo object with parsed device information
"""
user_agent = request.headers.get('user-agent', '')
user_agent = request.headers.get("user-agent", "")
device_info = DeviceInfo(
device_name=parse_device_name(user_agent),
device_id=request.headers.get('x-device-id'), # Client must send this header
device_id=request.headers.get("x-device-id"), # Client must send this header
ip_address=get_client_ip(request),
user_agent=user_agent[:500] if user_agent else None, # Truncate to max length
location_city=None, # Can be populated via IP geolocation service
@@ -33,7 +33,7 @@ def extract_device_info(request: Request) -> DeviceInfo:
return device_info
def parse_device_name(user_agent: str) -> Optional[str]:
def parse_device_name(user_agent: str) -> str | None:
"""
Parse user agent string to extract a friendly device name.
@@ -54,48 +54,48 @@ def parse_device_name(user_agent: str) -> Optional[str]:
user_agent_lower = user_agent.lower()
# Mobile devices (check first, as they can contain desktop patterns too)
if 'iphone' in user_agent_lower:
if "iphone" in user_agent_lower:
return "iPhone"
elif 'ipad' in user_agent_lower:
elif "ipad" in user_agent_lower:
return "iPad"
elif 'android' in user_agent_lower:
elif "android" in user_agent_lower:
# Try to extract device model
android_match = re.search(r'android.*;\s*([^)]+)\s*build', user_agent_lower)
android_match = re.search(r"android.*;\s*([^)]+)\s*build", user_agent_lower)
if android_match:
device_model = android_match.group(1).strip()
return f"Android ({device_model.title()})"
return "Android device"
elif 'windows phone' in user_agent_lower:
elif "windows phone" in user_agent_lower:
return "Windows Phone"
# Tablets (check before desktop, as some tablets contain "android")
elif 'tablet' in user_agent_lower:
elif "tablet" in user_agent_lower:
return "Tablet"
# Smart TVs (check before desktop OS patterns)
elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv']):
elif any(tv in user_agent_lower for tv in ["smart-tv", "smarttv"]):
return "Smart TV"
# Game consoles (check before desktop OS patterns, as Xbox contains "Windows")
elif 'playstation' in user_agent_lower:
elif "playstation" in user_agent_lower:
return "PlayStation"
elif 'xbox' in user_agent_lower:
elif "xbox" in user_agent_lower:
return "Xbox"
elif 'nintendo' in user_agent_lower:
elif "nintendo" in user_agent_lower:
return "Nintendo"
# Desktop operating systems
elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower:
elif "macintosh" in user_agent_lower or "mac os x" in user_agent_lower:
# Try to extract browser
browser = extract_browser(user_agent)
return f"{browser} on Mac" if browser else "Mac"
elif 'windows' in user_agent_lower:
elif "windows" in user_agent_lower:
browser = extract_browser(user_agent)
return f"{browser} on Windows" if browser else "Windows PC"
elif 'linux' in user_agent_lower and 'android' not in user_agent_lower:
elif "linux" in user_agent_lower and "android" not in user_agent_lower:
browser = extract_browser(user_agent)
return f"{browser} on Linux" if browser else "Linux"
elif 'cros' in user_agent_lower:
elif "cros" in user_agent_lower:
return "Chromebook"
# Fallback: just return browser name if detected
@@ -106,7 +106,7 @@ def parse_device_name(user_agent: str) -> Optional[str]:
return "Unknown device"
def extract_browser(user_agent: str) -> Optional[str]:
def extract_browser(user_agent: str) -> str | None:
"""
Extract browser name from user agent string.
@@ -126,26 +126,26 @@ def extract_browser(user_agent: str) -> Optional[str]:
user_agent_lower = user_agent.lower()
# Check specific browsers (order matters - check Edge before Chrome!)
if 'edg/' in user_agent_lower or 'edge/' in user_agent_lower:
if "edg/" in user_agent_lower or "edge/" in user_agent_lower:
return "Edge"
elif 'opr/' in user_agent_lower or 'opera' in user_agent_lower:
elif "opr/" in user_agent_lower or "opera" in user_agent_lower:
return "Opera"
elif 'chrome/' in user_agent_lower:
elif "chrome/" in user_agent_lower:
return "Chrome"
elif 'safari/' in user_agent_lower:
elif "safari/" in user_agent_lower:
# Make sure it's actually Safari, not Chrome (which also contains "Safari")
if 'chrome' not in user_agent_lower:
if "chrome" not in user_agent_lower:
return "Safari"
return None
elif 'firefox/' in user_agent_lower:
elif "firefox/" in user_agent_lower:
return "Firefox"
elif 'msie' in user_agent_lower or 'trident/' in user_agent_lower:
elif "msie" in user_agent_lower or "trident/" in user_agent_lower:
return "Internet Explorer"
return None
def get_client_ip(request: Request) -> Optional[str]:
def get_client_ip(request: Request) -> str | None:
"""
Extract client IP address from request, considering proxy headers.
@@ -163,14 +163,14 @@ def get_client_ip(request: Request) -> Optional[str]:
- request.client.host is fallback for direct connections
"""
# Check X-Forwarded-For (common in proxied environments)
x_forwarded_for = request.headers.get('x-forwarded-for')
x_forwarded_for = request.headers.get("x-forwarded-for")
if x_forwarded_for:
# Get the first IP (original client)
client_ip = x_forwarded_for.split(',')[0].strip()
client_ip = x_forwarded_for.split(",")[0].strip()
return client_ip
# Check X-Real-IP (used by some proxies like nginx)
x_real_ip = request.headers.get('x-real-ip')
x_real_ip = request.headers.get("x-real-ip")
if x_real_ip:
return x_real_ip.strip()
@@ -195,9 +195,17 @@ def is_mobile_device(user_agent: str) -> bool:
return False
mobile_patterns = [
'mobile', 'android', 'iphone', 'ipad', 'ipod',
'blackberry', 'windows phone', 'webos', 'opera mini',
'iemobile', 'mobile safari'
"mobile",
"android",
"iphone",
"ipad",
"ipod",
"blackberry",
"windows phone",
"webos",
"opera mini",
"iemobile",
"mobile safari",
]
user_agent_lower = user_agent.lower()
@@ -220,7 +228,7 @@ def get_device_type(user_agent: str) -> str:
user_agent_lower = user_agent.lower()
# Check for tablets first (they can contain "mobile" too)
if 'ipad' in user_agent_lower or 'tablet' in user_agent_lower:
if "ipad" in user_agent_lower or "tablet" in user_agent_lower:
return "tablet"
# Check for mobile
@@ -228,7 +236,7 @@ def get_device_type(user_agent: str) -> str:
return "mobile"
# Check for desktop OS patterns
if any(os in user_agent_lower for os in ['windows', 'macintosh', 'linux', 'cros']):
if any(os in user_agent_lower for os in ["windows", "macintosh", "linux", "cros"]):
return "desktop"
return "other"

View File

@@ -5,18 +5,21 @@ This module provides utilities for creating and verifying signed tokens,
useful for operations like file uploads, password resets, or any other
time-limited, single-use operations.
"""
import base64
import hashlib
import hmac
import json
import secrets
import time
from typing import Dict, Any, Optional
from typing import Any
from app.core.config import settings
def create_upload_token(file_path: str, content_type: str, expires_in: int = 300) -> str:
def create_upload_token(
file_path: str, content_type: str, expires_in: int = 300
) -> str:
"""
Create a signed token for secure file uploads.
@@ -40,34 +43,29 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300
"path": file_path,
"content_type": content_type,
"exp": int(time.time()) + expires_in,
"nonce": secrets.token_hex(8) # Add randomness to prevent token reuse
"nonce": secrets.token_hex(8), # Add randomness to prevent token reuse
}
# Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
# Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
# Combine payload and signature
token_data = {
"payload": payload,
"signature": signature
}
token_data = {"payload": payload, "signature": signature}
# Encode the final token
token_json = json.dumps(token_data)
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
return token
def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
def verify_upload_token(token: str) -> dict[str, Any] | None:
"""
Verify an upload token and return the payload if valid.
@@ -88,7 +86,7 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
"""
try:
# Decode the token
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(token_json)
# Extract payload and signature
@@ -96,11 +94,9 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
signature = token_data["signature"]
# Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):
@@ -136,34 +132,29 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str:
"email": email,
"exp": int(time.time()) + expires_in,
"nonce": secrets.token_hex(16), # Extra randomness
"purpose": "password_reset"
"purpose": "password_reset",
}
# Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
# Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
# Combine payload and signature
token_data = {
"payload": payload,
"signature": signature
}
token_data = {"payload": payload, "signature": signature}
# Encode the final token
token_json = json.dumps(token_data)
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
return token
def verify_password_reset_token(token: str) -> Optional[str]:
def verify_password_reset_token(token: str) -> str | None:
"""
Verify a password reset token and return the email if valid.
@@ -182,7 +173,7 @@ def verify_password_reset_token(token: str) -> Optional[str]:
"""
try:
# Decode the token
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(token_json)
# Extract payload and signature
@@ -194,11 +185,9 @@ def verify_password_reset_token(token: str) -> Optional[str]:
return None
# Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):
@@ -234,34 +223,29 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str:
"email": email,
"exp": int(time.time()) + expires_in,
"nonce": secrets.token_hex(16),
"purpose": "email_verification"
"purpose": "email_verification",
}
# Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
# Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
# Combine payload and signature
token_data = {
"payload": payload,
"signature": signature
}
token_data = {"payload": payload, "signature": signature}
# Encode the final token
token_json = json.dumps(token_data)
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
return token
def verify_email_verification_token(token: str) -> Optional[str]:
def verify_email_verification_token(token: str) -> str | None:
"""
Verify an email verification token and return the email if valid.
@@ -280,7 +264,7 @@ def verify_email_verification_token(token: str) -> Optional[str]:
"""
try:
# Decode the token
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(token_json)
# Extract payload and signature
@@ -292,11 +276,9 @@ def verify_email_verification_token(token: str) -> Optional[str]:
return None
# Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):

View File

@@ -9,17 +9,19 @@ from app.core.database import Base
logger = logging.getLogger(__name__)
def get_test_engine():
"""Create an SQLite in-memory engine specifically for testing"""
test_engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
echo=False,
)
return test_engine
def setup_test_db():
"""Create a test database and session factory"""
# Create a new engine for this test run
@@ -30,14 +32,12 @@ def setup_test_db():
# Create session factory
TestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=test_engine,
expire_on_commit=False
autocommit=False, autoflush=False, bind=test_engine, expire_on_commit=False
)
return test_engine, TestingSessionLocal
def teardown_test_db(engine):
"""Clean up after tests"""
# Drop all tables
@@ -46,13 +46,14 @@ def teardown_test_db(engine):
# Dispose of engine
engine.dispose()
async def get_async_test_engine():
"""Create an async SQLite in-memory engine specifically for testing"""
test_engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
echo=False,
)
return test_engine
@@ -69,7 +70,7 @@ async def setup_async_test_db():
autoflush=False,
bind=test_engine,
expire_on_commit=False,
class_=AsyncSession
class_=AsyncSession,
)
return test_engine, AsyncTestingSessionLocal

View File

@@ -1,15 +1,16 @@
# tests/api/dependencies/test_auth_dependencies.py
import pytest
import pytest_asyncio
import uuid
from unittest.mock import patch
import pytest
import pytest_asyncio
from fastapi import HTTPException
from app.api.dependencies.auth import (
get_current_user,
get_current_active_user,
get_current_superuser,
get_optional_current_user
get_current_user,
get_optional_current_user,
)
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
from app.models.user import User
@@ -24,7 +25,7 @@ def mock_token():
@pytest_asyncio.fixture
async def async_mock_user(async_test_db):
"""Async fixture to create and return a mock User instance."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
mock_user = User(
id=uuid.uuid4(),
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
"""Tests for get_current_user dependency"""
@pytest.mark.asyncio
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token):
async def test_get_current_user_success(
self, async_test_db, async_mock_user, mock_token
):
"""Test successfully getting the current user"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return user_id that matches our mock_user
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
"""Test when the token contains a user ID that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return a non-existent user ID
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = nonexistent_id
# Should raise HTTPException with 404 status
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
assert "User not found" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
async def test_get_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test when the user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Should raise HTTPException with 403 status
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_expired_token(self, async_test_db, mock_token):
"""Test with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Should raise HTTPException with 401 status
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
"""Test with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Should raise HTTPException with 401 status
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
"""Tests for get_optional_current_user dependency"""
@pytest.mark.asyncio
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token):
async def test_get_optional_current_user_with_token(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user with a valid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
@pytest.mark.asyncio
async def test_get_optional_current_user_no_token(self, async_test_db):
"""Test getting optional user with no token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Call the dependency with no token
user = await get_optional_current_user(db=session, token=None)
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token):
async def test_get_optional_current_user_invalid_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Call the dependency
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token):
async def test_get_optional_current_user_expired_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Call the dependency
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
async def test_get_optional_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user when user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency

View File

@@ -1,13 +1,12 @@
# tests/api/routes/test_health.py
from datetime import datetime
from unittest.mock import patch
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi import status
from fastapi.testclient import TestClient
from datetime import datetime
from sqlalchemy.exc import OperationalError
from app.main import app
from app.core.database import get_db
@pytest.fixture
@@ -121,7 +120,10 @@ class TestHealthEndpoint:
response = client.get("/health")
# Should succeed without authentication
assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE]
assert response.status_code in [
status.HTTP_200_OK,
status.HTTP_503_SERVICE_UNAVAILABLE,
]
def test_health_check_idempotent(self, client):
"""Test that multiple health checks return consistent results"""
@@ -142,7 +144,10 @@ class TestHealthEndpoint:
assert data1["environment"] == data2["environment"]
# Same database check status
assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"]
assert (
data1["checks"]["database"]["status"]
== data2["checks"]["database"]["status"]
)
def test_health_check_content_type(self, client):
"""Test that health check returns JSON content type"""

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@
"""
Tests for authentication endpoints.
"""
import pytest
import pytest_asyncio
from fastapi import status
@@ -19,8 +20,8 @@ class TestRegisterEndpoint:
"email": "newuser@example.com",
"password": "NewPassword123!",
"first_name": "New",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_201_CREATED
@@ -36,8 +37,8 @@ class TestRegisterEndpoint:
"email": async_test_user.email,
"password": "TestPassword123!",
"first_name": "Test",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -51,8 +52,8 @@ class TestRegisterEndpoint:
"email": "test@example.com",
"password": "weak",
"first_name": "Test",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -66,10 +67,7 @@ class TestLoginEndpoint:
"""Test successful login."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_200_OK
@@ -82,10 +80,7 @@ class TestLoginEndpoint:
"""Test login with invalid password."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "WrongPassword123!"
}
json={"email": "testuser@example.com", "password": "WrongPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -95,10 +90,7 @@ class TestLoginEndpoint:
"""Test login with non-existent user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "nonexistent@example.com",
"password": "TestPassword123!"
}
json={"email": "nonexistent@example.com", "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -106,27 +98,25 @@ class TestLoginEndpoint:
@pytest.mark.asyncio
async def test_login_inactive_user(self, client, async_test_db):
"""Test login with inactive user."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
from app.models.user import User
from app.core.auth import get_password_hash
from app.models.user import User
inactive_user = User(
email="inactive@example.com",
password_hash=get_password_hash("TestPassword123!"),
first_name="Inactive",
last_name="User",
is_active=False
is_active=False,
)
session.add(inactive_user)
await session.commit()
response = await client.post(
"/api/v1/auth/login",
json={
"email": "inactive@example.com",
"password": "TestPassword123!"
}
json={"email": "inactive@example.com", "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -140,10 +130,7 @@ class TestRefreshTokenEndpoint:
"""Get a refresh token for testing."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
return response.json()["refresh_token"]
@@ -151,8 +138,7 @@ class TestRefreshTokenEndpoint:
async def test_refresh_token_success(self, client, refresh_token):
"""Test successful token refresh."""
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
assert response.status_code == status.HTTP_200_OK
@@ -164,8 +150,7 @@ class TestRefreshTokenEndpoint:
async def test_refresh_token_invalid(self, client):
"""Test refresh with invalid token."""
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid.token.here"}
"/api/v1/auth/refresh", json={"refresh_token": "invalid.token.here"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -179,13 +164,13 @@ class TestLogoutEndpoint:
"""Get tokens for testing."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
data = response.json()
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
return {
"access_token": data["access_token"],
"refresh_token": data["refresh_token"],
}
@pytest.mark.asyncio
async def test_logout_success(self, client, tokens):
@@ -193,7 +178,7 @@ class TestLogoutEndpoint:
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]}
json={"refresh_token": tokens["refresh_token"]},
)
assert response.status_code == status.HTTP_200_OK
@@ -202,8 +187,7 @@ class TestLogoutEndpoint:
async def test_logout_without_auth(self, client):
"""Test logout without authentication."""
response = await client.post(
"/api/v1/auth/logout",
json={"refresh_token": "some.token"}
"/api/v1/auth/logout", json={"refresh_token": "some.token"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -215,8 +199,7 @@ class TestPasswordResetRequest:
async def test_password_reset_request_success(self, client, async_test_user):
"""Test password reset request with existing user."""
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
"/api/v1/auth/password-reset/request", json={"email": async_test_user.email}
)
assert response.status_code == status.HTTP_200_OK
@@ -228,7 +211,7 @@ class TestPasswordResetRequest:
"""Test password reset request with non-existent email."""
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": "nonexistent@example.com"}
json={"email": "nonexistent@example.com"},
)
assert response.status_code == status.HTTP_200_OK
@@ -244,10 +227,7 @@ class TestPasswordResetConfirm:
"""Test password reset with invalid token."""
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": "invalid.token.here",
"new_password": "NewPassword123!"
}
json={"token": "invalid.token.here", "new_password": "NewPassword123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -261,20 +241,20 @@ class TestLogoutAll:
"""Get tokens for testing."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
data = response.json()
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
return {
"access_token": data["access_token"],
"refresh_token": data["refresh_token"],
}
@pytest.mark.asyncio
async def test_logout_all_success(self, client, tokens):
"""Test logout from all devices."""
response = await client.post(
"/api/v1/auth/logout-all",
headers={"Authorization": f"Bearer {tokens['access_token']}"}
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -298,10 +278,7 @@ class TestOAuthLogin:
"""Test successful OAuth login."""
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": "testuser@example.com",
"password": "TestPassword123!"
}
data={"username": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_200_OK
@@ -315,10 +292,7 @@ class TestOAuthLogin:
"""Test OAuth login with invalid credentials."""
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": "testuser@example.com",
"password": "WrongPassword"
}
data={"username": "testuser@example.com", "password": "WrongPassword"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -1,15 +1,16 @@
# tests/api/dependencies/test_auth_dependencies.py
import pytest
import pytest_asyncio
import uuid
from unittest.mock import patch
import pytest
import pytest_asyncio
from fastapi import HTTPException
from app.api.dependencies.auth import (
get_current_user,
get_current_active_user,
get_current_superuser,
get_optional_current_user
get_current_user,
get_optional_current_user,
)
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
from app.models.user import User
@@ -24,7 +25,7 @@ def mock_token():
@pytest_asyncio.fixture
async def async_mock_user(async_test_db):
"""Async fixture to create and return a mock User instance."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
mock_user = User(
id=uuid.uuid4(),
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
"""Tests for get_current_user dependency"""
@pytest.mark.asyncio
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token):
async def test_get_current_user_success(
self, async_test_db, async_mock_user, mock_token
):
"""Test successfully getting the current user"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return user_id that matches our mock_user
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
"""Test when the token contains a user ID that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to return a non-existent user ID
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = nonexistent_id
# Should raise HTTPException with 404 status
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
assert "User not found" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
async def test_get_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test when the user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Should raise HTTPException with 403 status
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_expired_token(self, async_test_db, mock_token):
"""Test with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Should raise HTTPException with 401 status
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
@pytest.mark.asyncio
async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
"""Test with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Should raise HTTPException with 401 status
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
"""Tests for get_optional_current_user dependency"""
@pytest.mark.asyncio
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token):
async def test_get_optional_current_user_with_token(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user with a valid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
@pytest.mark.asyncio
async def test_get_optional_current_user_no_token(self, async_test_db):
"""Test getting optional user with no token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Call the dependency with no token
user = await get_optional_current_user(db=session, token=None)
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token):
async def test_get_optional_current_user_invalid_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenInvalidError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenInvalidError("Invalid token")
# Call the dependency
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token):
async def test_get_optional_current_user_expired_token(
self, async_test_db, mock_token
):
"""Test getting optional user with an expired token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock get_token_data to raise TokenExpiredError
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.side_effect = TokenExpiredError("Token expired")
# Call the dependency
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
assert user is None
@pytest.mark.asyncio
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
async def test_get_optional_current_user_inactive(
self, async_test_db, async_mock_user, mock_token
):
"""Test getting optional user when user is inactive"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == async_mock_user.id))
result = await session.execute(
select(User).where(User.id == async_mock_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
# Mock get_token_data
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
mock_get_data.return_value.user_id = async_mock_user.id
# Call the dependency

View File

@@ -2,21 +2,21 @@
"""
Tests for authentication endpoints.
"""
from unittest.mock import patch
import pytest
import pytest_asyncio
from unittest.mock import patch, MagicMock
from fastapi import status
from sqlalchemy import select
from app.models.user import User
from app.schemas.users import UserCreate
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.auth.limiter.enabled', False):
with patch("app.api.routes.auth.limiter.enabled", False):
yield
@@ -32,8 +32,8 @@ class TestRegisterEndpoint:
"email": "newuser@example.com",
"password": "SecurePassword123!",
"first_name": "New",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_201_CREATED
@@ -54,8 +54,8 @@ class TestRegisterEndpoint:
"email": async_test_user.email,
"password": "SecurePassword123!",
"first_name": "Duplicate",
"last_name": "User"
}
"last_name": "User",
},
)
# Security: Returns 400 with generic message to prevent email enumeration
@@ -73,8 +73,8 @@ class TestRegisterEndpoint:
"email": "weakpass@example.com",
"password": "weak",
"first_name": "Weak",
"last_name": "Pass"
}
"last_name": "Pass",
},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -82,7 +82,7 @@ class TestRegisterEndpoint:
@pytest.mark.asyncio
async def test_register_unexpected_error(self, client):
"""Test registration with unexpected error."""
with patch('app.services.auth_service.AuthService.create_user') as mock_create:
with patch("app.services.auth_service.AuthService.create_user") as mock_create:
mock_create.side_effect = Exception("Unexpected error")
response = await client.post(
@@ -91,8 +91,8 @@ class TestRegisterEndpoint:
"email": "error@example.com",
"password": "SecurePassword123!",
"first_name": "Error",
"last_name": "User"
}
"last_name": "User",
},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -106,10 +106,7 @@ class TestLoginEndpoint:
"""Test successful login."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_200_OK
@@ -123,10 +120,7 @@ class TestLoginEndpoint:
"""Test login with wrong password."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "WrongPassword123"
}
json={"email": async_test_user.email, "password": "WrongPassword123"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -136,10 +130,7 @@ class TestLoginEndpoint:
"""Test login with non-existent email."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "nonexistent@example.com",
"password": "Password123!"
}
json={"email": "nonexistent@example.com", "password": "Password123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -147,20 +138,19 @@ class TestLoginEndpoint:
@pytest.mark.asyncio
async def test_login_inactive_user(self, client, async_test_user, async_test_db):
"""Test login with inactive user."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -168,15 +158,14 @@ class TestLoginEndpoint:
@pytest.mark.asyncio
async def test_login_unexpected_error(self, client, async_test_user):
"""Test login with unexpected error."""
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
with patch(
"app.services.auth_service.AuthService.authenticate_user"
) as mock_auth:
mock_auth.side_effect = Exception("Database error")
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -190,10 +179,7 @@ class TestOAuthLoginEndpoint:
"""Test successful OAuth login."""
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "TestPassword123!"
}
data={"username": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_200_OK
@@ -206,31 +192,29 @@ class TestOAuthLoginEndpoint:
"""Test OAuth login with wrong credentials."""
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "WrongPassword"
}
data={"username": async_test_user.email, "password": "WrongPassword"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.asyncio
async def test_oauth_login_inactive_user(self, client, async_test_user, async_test_db):
async def test_oauth_login_inactive_user(
self, client, async_test_user, async_test_db
):
"""Test OAuth login with inactive user."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get the user in this session and make it inactive
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "TestPassword123!"
}
data={"username": async_test_user.email, "password": "TestPassword123!"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -238,15 +222,17 @@ class TestOAuthLoginEndpoint:
@pytest.mark.asyncio
async def test_oauth_login_unexpected_error(self, client, async_test_user):
"""Test OAuth login with unexpected error."""
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
with patch(
"app.services.auth_service.AuthService.authenticate_user"
) as mock_auth:
mock_auth.side_effect = Exception("Unexpected error")
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": async_test_user.email,
"password": "TestPassword123!"
}
"password": "TestPassword123!",
},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -261,17 +247,13 @@ class TestRefreshTokenEndpoint:
# First, login to get a refresh token
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
refresh_token = login_response.json()["refresh_token"]
# Now refresh the token
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
assert response.status_code == status.HTTP_200_OK
@@ -284,12 +266,13 @@ class TestRefreshTokenEndpoint:
"""Test refresh with expired token."""
from app.core.auth import TokenExpiredError
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
with patch(
"app.services.auth_service.AuthService.refresh_tokens"
) as mock_refresh:
mock_refresh.side_effect = TokenExpiredError("Token expired")
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "some_token"}
"/api/v1/auth/refresh", json={"refresh_token": "some_token"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -298,8 +281,7 @@ class TestRefreshTokenEndpoint:
async def test_refresh_token_invalid(self, client):
"""Test refresh with invalid token."""
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid_token"}
"/api/v1/auth/refresh", json={"refresh_token": "invalid_token"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -310,19 +292,17 @@ class TestRefreshTokenEndpoint:
# Get a valid refresh token first
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "TestPassword123!"
}
json={"email": async_test_user.email, "password": "TestPassword123!"},
)
refresh_token = login_response.json()["refresh_token"]
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
with patch(
"app.services.auth_service.AuthService.refresh_tokens"
) as mock_refresh:
mock_refresh.side_effect = Exception("Unexpected error")
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

View File

@@ -2,8 +2,10 @@
"""
Tests for auth route exception handlers and error paths.
"""
from unittest.mock import patch
import pytest
from unittest.mock import patch, AsyncMock
from fastapi import status
@@ -11,16 +13,18 @@ class TestLoginSessionCreationFailure:
"""Test login when session creation fails."""
@pytest.mark.asyncio
async def test_login_succeeds_despite_session_creation_failure(self, client, async_test_user):
async def test_login_succeeds_despite_session_creation_failure(
self, client, async_test_user
):
"""Test that login succeeds even if session creation fails."""
# Mock session creation to fail
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session creation failed")):
with patch(
"app.api.routes.auth.session_crud.create_session",
side_effect=Exception("Session creation failed"),
):
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
# Login should still succeed, just without session record
@@ -34,15 +38,20 @@ class TestOAuthLoginSessionCreationFailure:
"""Test OAuth login when session creation fails."""
@pytest.mark.asyncio
async def test_oauth_login_succeeds_despite_session_failure(self, client, async_test_user):
async def test_oauth_login_succeeds_despite_session_failure(
self, client, async_test_user
):
"""Test OAuth login succeeds even if session creation fails."""
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session failed")):
with patch(
"app.api.routes.auth.session_crud.create_session",
side_effect=Exception("Session failed"),
):
response = await client.post(
"/api/v1/auth/login/oauth",
data={
"username": "testuser@example.com",
"password": "TestPassword123!"
}
"password": "TestPassword123!",
},
)
assert response.status_code == status.HTTP_200_OK
@@ -54,23 +63,24 @@ class TestRefreshTokenSessionUpdateFailure:
"""Test refresh token when session update fails."""
@pytest.mark.asyncio
async def test_refresh_token_succeeds_despite_session_update_failure(self, client, async_test_user):
async def test_refresh_token_succeeds_despite_session_update_failure(
self, client, async_test_user
):
"""Test that token refresh succeeds even if session update fails."""
# First login to get tokens
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
tokens = response.json()
# Mock session update to fail
with patch('app.api.routes.auth.session_crud.update_refresh_token', side_effect=Exception("Update failed")):
with patch(
"app.api.routes.auth.session_crud.update_refresh_token",
side_effect=Exception("Update failed"),
):
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": tokens["refresh_token"]}
"/api/v1/auth/refresh", json={"refresh_token": tokens["refresh_token"]}
)
# Should still succeed - tokens are issued before update
@@ -83,15 +93,14 @@ class TestLogoutWithExpiredToken:
"""Test logout with expired/invalid token."""
@pytest.mark.asyncio
async def test_logout_with_invalid_token_still_succeeds(self, client, async_test_user):
async def test_logout_with_invalid_token_still_succeeds(
self, client, async_test_user
):
"""Test logout succeeds even with invalid refresh token."""
# Login first
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
access_token = response.json()["access_token"]
@@ -99,7 +108,7 @@ class TestLogoutWithExpiredToken:
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {access_token}"},
json={"refresh_token": "invalid.token.here"}
json={"refresh_token": "invalid.token.here"},
)
# Should succeed (idempotent)
@@ -116,19 +125,16 @@ class TestLogoutWithNonExistentSession:
"""Test logout succeeds even if session not found."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
tokens = response.json()
# Mock session lookup to return None
with patch('app.api.routes.auth.session_crud.get_by_jti', return_value=None):
with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None):
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]}
json={"refresh_token": tokens["refresh_token"]},
)
# Should succeed (idempotent)
@@ -139,23 +145,25 @@ class TestLogoutUnexpectedError:
"""Test logout with unexpected errors."""
@pytest.mark.asyncio
async def test_logout_with_unexpected_error_returns_success(self, client, async_test_user):
async def test_logout_with_unexpected_error_returns_success(
self, client, async_test_user
):
"""Test logout returns success even on unexpected errors."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
tokens = response.json()
# Mock to raise unexpected error
with patch('app.api.routes.auth.session_crud.get_by_jti', side_effect=Exception("Unexpected error")):
with patch(
"app.api.routes.auth.session_crud.get_by_jti",
side_effect=Exception("Unexpected error"),
):
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]}
json={"refresh_token": tokens["refresh_token"]},
)
# Should still return success (don't expose errors)
@@ -172,18 +180,18 @@ class TestLogoutAllUnexpectedError:
"""Test logout-all handles database errors."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
access_token = response.json()["access_token"]
# Mock to raise database error
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("DB error")):
with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
side_effect=Exception("DB error"),
):
response = await client.post(
"/api/v1/auth/logout-all",
headers={"Authorization": f"Bearer {access_token}"}
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -193,7 +201,9 @@ class TestPasswordResetConfirmSessionInvalidation:
"""Test password reset invalidates sessions."""
@pytest.mark.asyncio
async def test_password_reset_continues_despite_session_invalidation_failure(self, client, async_test_user):
async def test_password_reset_continues_despite_session_invalidation_failure(
self, client, async_test_user
):
"""Test password reset succeeds even if session invalidation fails."""
# Create a valid password reset token
from app.utils.security import create_password_reset_token
@@ -201,13 +211,13 @@ class TestPasswordResetConfirmSessionInvalidation:
token = create_password_reset_token(async_test_user.email)
# Mock session invalidation to fail
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("Invalidation failed")):
with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
side_effect=Exception("Invalidation failed"),
):
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewPassword123!"
}
json={"token": token, "new_password": "NewPassword123!"},
)
# Should still succeed - password was reset

View File

@@ -2,22 +2,22 @@
"""
Tests for password reset endpoints.
"""
from unittest.mock import patch
import pytest
import pytest_asyncio
from unittest.mock import patch, AsyncMock, MagicMock
from fastapi import status
from sqlalchemy import select
from app.schemas.users import PasswordResetRequest, PasswordResetConfirm
from app.utils.security import create_password_reset_token
from app.models.user import User
from app.utils.security import create_password_reset_token
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.auth.limiter.enabled', False):
with patch("app.api.routes.auth.limiter.enabled", False):
yield
@@ -27,12 +27,14 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio
async def test_password_reset_request_valid_email(self, client, async_test_user):
"""Test password reset request with valid email."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.return_value = True
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
assert response.status_code == status.HTTP_200_OK
@@ -50,10 +52,12 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio
async def test_password_reset_request_nonexistent_email(self, client):
"""Test password reset request with non-existent email."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": "nonexistent@example.com"}
json={"email": "nonexistent@example.com"},
)
# Should still return success to prevent email enumeration
@@ -65,20 +69,26 @@ class TestPasswordResetRequest:
mock_send.assert_not_called()
@pytest.mark.asyncio
async def test_password_reset_request_inactive_user(self, client, async_test_db, async_test_user):
async def test_password_reset_request_inactive_user(
self, client, async_test_db, async_test_user
):
"""Test password reset request with inactive user."""
# Deactivate user
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
# Should still return success to prevent email enumeration
@@ -93,8 +103,7 @@ class TestPasswordResetRequest:
async def test_password_reset_request_invalid_email_format(self, client):
"""Test password reset request with invalid email format."""
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": "not-an-email"}
"/api/v1/auth/password-reset/request", json={"email": "not-an-email"}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -102,22 +111,23 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio
async def test_password_reset_request_missing_email(self, client):
"""Test password reset request without email."""
response = await client.post(
"/api/v1/auth/password-reset/request",
json={}
)
response = await client.post("/api/v1/auth/password-reset/request", json={})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_password_reset_request_email_service_error(self, client, async_test_user):
async def test_password_reset_request_email_service_error(
self, client, async_test_user
):
"""Test password reset when email service fails."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.side_effect = Exception("SMTP Error")
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
# Should still return success even if email fails
@@ -128,14 +138,16 @@ class TestPasswordResetRequest:
@pytest.mark.asyncio
async def test_password_reset_request_rate_limiting(self, client, async_test_user):
"""Test that password reset requests are rate limited."""
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.return_value = True
# Make multiple requests quickly (3/minute limit)
for _ in range(3):
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
assert response.status_code == status.HTTP_200_OK
@@ -144,7 +156,9 @@ class TestPasswordResetConfirm:
"""Tests for POST /auth/password-reset/confirm endpoint."""
@pytest.mark.asyncio
async def test_password_reset_confirm_valid_token(self, client, async_test_user, async_test_db):
async def test_password_reset_confirm_valid_token(
self, client, async_test_user, async_test_db
):
"""Test password reset confirmation with valid token."""
# Generate valid token
token = create_password_reset_token(async_test_user.email)
@@ -152,10 +166,7 @@ class TestPasswordResetConfirm:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": new_password
}
json={"token": token, "new_password": new_password},
)
assert response.status_code == status.HTTP_200_OK
@@ -164,11 +175,14 @@ class TestPasswordResetConfirm:
assert "successfully" in data["message"].lower()
# Verify user can login with new password
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
updated_user = result.scalar_one_or_none()
from app.core.auth import verify_password
assert verify_password(new_password, updated_user.password_hash) is True
@pytest.mark.asyncio
@@ -184,10 +198,7 @@ class TestPasswordResetConfirm:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123!"
}
json={"token": token, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -202,10 +213,7 @@ class TestPasswordResetConfirm:
"""Test password reset confirmation with invalid token."""
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": "invalid_token_xyz",
"new_password": "NewSecure123!"
}
json={"token": "invalid_token_xyz", "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -222,19 +230,18 @@ class TestPasswordResetConfirm:
# Create valid token and tamper with it
token = create_password_reset_token(async_test_user.email)
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com"
# Re-encode tampered token
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
tampered = base64.urlsafe_b64encode(
json.dumps(token_data).encode("utf-8")
).decode("utf-8")
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": tampered,
"new_password": "NewSecure123!"
}
json={"token": tampered, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -247,10 +254,7 @@ class TestPasswordResetConfirm:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123!"
}
json={"token": token, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -260,12 +264,16 @@ class TestPasswordResetConfirm:
assert "not found" in error_msg
@pytest.mark.asyncio
async def test_password_reset_confirm_inactive_user(self, client, async_test_user, async_test_db):
async def test_password_reset_confirm_inactive_user(
self, client, async_test_user, async_test_db
):
"""Test password reset confirmation for inactive user."""
# Deactivate user
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user_in_session = result.scalar_one_or_none()
user_in_session.is_active = False
await session.commit()
@@ -274,10 +282,7 @@ class TestPasswordResetConfirm:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123!"
}
json={"token": token, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -301,10 +306,7 @@ class TestPasswordResetConfirm:
for weak_password in weak_passwords:
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": weak_password
}
json={"token": token, "new_password": weak_password},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -315,15 +317,14 @@ class TestPasswordResetConfirm:
# Missing token
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={"new_password": "NewSecure123!"}
json={"new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
# Missing password
token = create_password_reset_token("test@example.com")
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={"token": token}
"/api/v1/auth/password-reset/confirm", json={"token": token}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -333,15 +334,12 @@ class TestPasswordResetConfirm:
token = create_password_reset_token(async_test_user.email)
# Mock the database commit to raise an exception
with patch('app.api.routes.auth.user_crud.get_by_email') as mock_get:
with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get:
mock_get.side_effect = Exception("Database error")
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": token,
"new_password": "NewSecure123!"
}
json={"token": token, "new_password": "NewSecure123!"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -351,18 +349,22 @@ class TestPasswordResetConfirm:
assert "error" in error_msg or "resetting" in error_msg
@pytest.mark.asyncio
async def test_password_reset_full_flow(self, client, async_test_user, async_test_db):
async def test_password_reset_full_flow(
self, client, async_test_user, async_test_db
):
"""Test complete password reset flow."""
original_password = async_test_user.password_hash
new_password = "BrandNew123!"
# Step 1: Request password reset
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
with patch(
"app.api.routes.auth.email_service.send_password_reset_email"
) as mock_send:
mock_send.return_value = True
response = await client.post(
"/api/v1/auth/password-reset/request",
json={"email": async_test_user.email}
json={"email": async_test_user.email},
)
assert response.status_code == status.HTTP_200_OK
@@ -374,29 +376,24 @@ class TestPasswordResetConfirm:
# Step 2: Confirm password reset
response = await client.post(
"/api/v1/auth/password-reset/confirm",
json={
"token": reset_token,
"new_password": new_password
}
json={"token": reset_token, "new_password": new_password},
)
assert response.status_code == status.HTTP_200_OK
# Step 3: Verify old password doesn't work
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
updated_user = result.scalar_one_or_none()
from app.core.auth import verify_password
assert updated_user.password_hash != original_password
# Step 4: Verify new password works
response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": new_password
}
json={"email": async_test_user.email, "password": new_password},
)
assert response.status_code == status.HTTP_200_OK

View File

@@ -8,11 +8,10 @@ Critical security tests covering:
These tests prevent real-world attack scenarios.
"""
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import create_refresh_token
from app.crud.session import session as session_crud
from app.models.user import User
@@ -30,10 +29,7 @@ class TestRevokedSessionSecurity:
@pytest.mark.asyncio
async def test_refresh_token_rejected_after_logout(
self,
client: AsyncClient,
async_test_db,
async_test_user: User
self, client: AsyncClient, async_test_db, async_test_user: User
):
"""
Test that refresh tokens are rejected after session is deactivated.
@@ -45,10 +41,10 @@ class TestRevokedSessionSecurity:
4. Attacker tries to use stolen refresh token
5. System MUST reject it (session revoked)
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Step 1: Create a session and refresh token for the user
async with SessionLocal() as session:
async with SessionLocal():
# Login to get tokens
response = await client.post(
"/api/v1/auth/login",
@@ -64,8 +60,7 @@ class TestRevokedSessionSecurity:
# Step 2: Verify refresh token works before logout
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
assert response.status_code == 200, "Refresh should work before logout"
@@ -73,14 +68,13 @@ class TestRevokedSessionSecurity:
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {access_token}"},
json={"refresh_token": refresh_token}
json={"refresh_token": refresh_token},
)
assert response.status_code == 200, "Logout should succeed"
# Step 4: Attacker tries to use stolen refresh token
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
# Step 5: System MUST reject (covers lines 261-262)
@@ -93,10 +87,7 @@ class TestRevokedSessionSecurity:
@pytest.mark.asyncio
async def test_refresh_token_rejected_for_deleted_session(
self,
client: AsyncClient,
async_test_db,
async_test_user: User
self, client: AsyncClient, async_test_db, async_test_user: User
):
"""
Test that tokens for deleted sessions are rejected.
@@ -104,7 +95,7 @@ class TestRevokedSessionSecurity:
Attack Scenario:
Admin deletes a session from database, but attacker has the token.
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Step 1: Login to create a session
response = await client.post(
@@ -120,6 +111,7 @@ class TestRevokedSessionSecurity:
# Step 2: Manually delete the session from database (simulating admin action)
from app.core.auth import decode_token
token_data = decode_token(refresh_token, verify_type="refresh")
jti = token_data.jti
@@ -132,15 +124,17 @@ class TestRevokedSessionSecurity:
# Step 3: Try to use the refresh token
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
)
# Should reject (session doesn't exist)
assert response.status_code == 401
data = response.json()
if "errors" in data:
assert "revoked" in data["errors"][0]["message"].lower() or "session" in data["errors"][0]["message"].lower()
assert (
"revoked" in data["errors"][0]["message"].lower()
or "session" in data["errors"][0]["message"].lower()
)
else:
assert "revoked" in data.get("detail", "").lower()
@@ -162,7 +156,7 @@ class TestSessionHijackingSecurity:
client: AsyncClient,
async_test_db,
async_test_user: User,
async_test_superuser: User
async_test_superuser: User,
):
"""
Test that users cannot logout other users' sessions.
@@ -173,7 +167,7 @@ class TestSessionHijackingSecurity:
3. User A tries to logout User B's session
4. System MUST reject (cross-user attack)
"""
test_engine, SessionLocal = async_test_db
_test_engine, _SessionLocal = async_test_db
# Step 1: User A logs in
response = await client.post(
@@ -202,8 +196,10 @@ class TestSessionHijackingSecurity:
# Step 3: User A tries to logout User B's session using User B's refresh token
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {user_a_access}"}, # User A's access token
json={"refresh_token": user_b_refresh} # But User B's refresh token
headers={
"Authorization": f"Bearer {user_a_access}"
}, # User A's access token
json={"refresh_token": user_b_refresh}, # But User B's refresh token
)
# Step 4: System MUST reject (covers lines 509-513)
@@ -217,9 +213,7 @@ class TestSessionHijackingSecurity:
@pytest.mark.asyncio
async def test_users_can_logout_their_own_sessions(
self,
client: AsyncClient,
async_test_user: User
self, client: AsyncClient, async_test_user: User
):
"""
Sanity check: Users CAN logout their own sessions.
@@ -241,6 +235,8 @@ class TestSessionHijackingSecurity:
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
json={"refresh_token": tokens["refresh_token"]}
json={"refresh_token": tokens["refresh_token"]},
)
assert response.status_code == 200, (
"Users should be able to logout their own sessions"
)
assert response.status_code == 200, "Users should be able to logout their own sessions"

View File

@@ -5,16 +5,18 @@ Tests for organization routes (user endpoints).
These test the routes in app/api/routes/organizations.py which allow
users to view and manage organizations they belong to.
"""
from unittest.mock import patch
from uuid import uuid4
import pytest
import pytest_asyncio
from fastapi import status
from uuid import uuid4
from unittest.mock import patch, AsyncMock
from app.core.auth import get_password_hash
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.core.auth import get_password_hash
from app.models.user_organization import OrganizationRole, UserOrganization
@pytest_asyncio.fixture
@@ -22,10 +24,7 @@ async def user_token(client, async_test_user):
"""Get access token for regular user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -34,7 +33,7 @@ async def user_token(client, async_test_user):
@pytest_asyncio.fixture
async def second_user(async_test_db):
"""Create a second test user."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid4(),
@@ -56,12 +55,12 @@ async def second_user(async_test_db):
@pytest_asyncio.fixture
async def test_org_with_user_member(async_test_db, async_test_user):
"""Create a test organization with async_test_user as a member."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Member Org",
slug="member-org",
description="Test organization where user is a member"
description="Test organization where user is a member",
)
session.add(org)
await session.commit()
@@ -72,7 +71,7 @@ async def test_org_with_user_member(async_test_db, async_test_user):
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.MEMBER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -83,12 +82,12 @@ async def test_org_with_user_member(async_test_db, async_test_user):
@pytest_asyncio.fixture
async def test_org_with_user_admin(async_test_db, async_test_user):
"""Create a test organization with async_test_user as an admin."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Admin Org",
slug="admin-org",
description="Test organization where user is an admin"
description="Test organization where user is an admin",
)
session.add(org)
await session.commit()
@@ -99,7 +98,7 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.ADMIN,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -110,12 +109,12 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
@pytest_asyncio.fixture
async def test_org_with_user_owner(async_test_db, async_test_user):
"""Create a test organization with async_test_user as owner."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Owner Org",
slug="owner-org",
description="Test organization where user is owner"
description="Test organization where user is owner",
)
session.add(org)
await session.commit()
@@ -126,7 +125,7 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.OWNER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -136,21 +135,18 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
# ===== GET /api/v1/organizations/me =====
class TestGetMyOrganizations:
"""Tests for GET /api/v1/organizations/me endpoint."""
@pytest.mark.asyncio
async def test_get_my_organizations_success(
self,
client,
user_token,
test_org_with_user_member,
test_org_with_user_admin
self, client, user_token, test_org_with_user_member, test_org_with_user_admin
):
"""Test successfully getting user's organizations (covers lines 54-79)."""
response = await client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -167,21 +163,15 @@ class TestGetMyOrganizations:
@pytest.mark.asyncio
async def test_get_my_organizations_filter_active(
self,
client,
async_test_db,
async_test_user,
user_token
self, client, async_test_db, async_test_user, user_token
):
"""Test filtering organizations by active status."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create active org
async with AsyncTestingSessionLocal() as session:
active_org = Organization(
name="Active Org",
slug="active-org-filter",
is_active=True
name="Active Org", slug="active-org-filter", is_active=True
)
session.add(active_org)
await session.commit()
@@ -192,14 +182,14 @@ class TestGetMyOrganizations:
user_id=async_test_user.id,
organization_id=active_org.id,
role=OrganizationRole.MEMBER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
response = await client.get(
"/api/v1/organizations/me?is_active=true",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -209,7 +199,7 @@ class TestGetMyOrganizations:
@pytest.mark.asyncio
async def test_get_my_organizations_empty(self, client, async_test_db):
"""Test getting organizations when user has none."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user with no org memberships
async with AsyncTestingSessionLocal() as session:
@@ -219,7 +209,7 @@ class TestGetMyOrganizations:
password_hash=get_password_hash("TestPassword123!"),
first_name="No",
last_name="Org",
is_active=True
is_active=True,
)
session.add(user)
await session.commit()
@@ -227,13 +217,12 @@ class TestGetMyOrganizations:
# Login to get token
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "noorg@example.com", "password": "TestPassword123!"}
json={"email": "noorg@example.com", "password": "TestPassword123!"},
)
token = login_response.json()["access_token"]
response = await client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {token}"}
"/api/v1/organizations/me", headers={"Authorization": f"Bearer {token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -243,20 +232,18 @@ class TestGetMyOrganizations:
# ===== GET /api/v1/organizations/{organization_id} =====
class TestGetOrganization:
"""Tests for GET /api/v1/organizations/{organization_id} endpoint."""
@pytest.mark.asyncio
async def test_get_organization_success(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test successfully getting organization details (covers lines 103-122)."""
response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -272,7 +259,7 @@ class TestGetOrganization:
fake_org_id = uuid4()
response = await client.get(
f"/api/v1/organizations/{fake_org_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Permission dependency checks membership before endpoint logic
@@ -283,20 +270,14 @@ class TestGetOrganization:
@pytest.mark.asyncio
async def test_get_organization_not_member(
self,
client,
async_test_db,
async_test_user
self, client, async_test_db, async_test_user
):
"""Test getting organization where user is not a member fails."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create org without adding user
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Not Member Org",
slug="not-member-org"
)
org = Organization(name="Not Member Org", slug="not-member-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -305,13 +286,13 @@ class TestGetOrganization:
# Login as user
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
token = login_response.json()["access_token"]
response = await client.get(
f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {token}"}
headers={"Authorization": f"Bearer {token}"},
)
# Should fail permission check
@@ -320,6 +301,7 @@ class TestGetOrganization:
# ===== GET /api/v1/organizations/{organization_id}/members =====
class TestGetOrganizationMembers:
"""Tests for GET /api/v1/organizations/{organization_id}/members endpoint."""
@@ -331,10 +313,10 @@ class TestGetOrganizationMembers:
async_test_user,
second_user,
user_token,
test_org_with_user_member
test_org_with_user_member,
):
"""Test successfully getting organization members (covers lines 150-168)."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Add second user to org
async with AsyncTestingSessionLocal() as session:
@@ -342,14 +324,14 @@ class TestGetOrganizationMembers:
user_id=second_user.id,
organization_id=test_org_with_user_member.id,
role=OrganizationRole.MEMBER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -360,15 +342,12 @@ class TestGetOrganizationMembers:
@pytest.mark.asyncio
async def test_get_organization_members_with_pagination(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test pagination parameters."""
response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members?page=1&limit=10",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -385,10 +364,10 @@ class TestGetOrganizationMembers:
async_test_user,
second_user,
user_token,
test_org_with_user_member
test_org_with_user_member,
):
"""Test filtering members by active status."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Add second user as inactive member
async with AsyncTestingSessionLocal() as session:
@@ -396,7 +375,7 @@ class TestGetOrganizationMembers:
user_id=second_user.id,
organization_id=test_org_with_user_member.id,
role=OrganizationRole.MEMBER,
is_active=False
is_active=False,
)
session.add(membership)
await session.commit()
@@ -404,7 +383,7 @@ class TestGetOrganizationMembers:
# Filter for active only
response = await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members?is_active=true",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -416,31 +395,26 @@ class TestGetOrganizationMembers:
# ===== PUT /api/v1/organizations/{organization_id} =====
class TestUpdateOrganization:
"""Tests for PUT /api/v1/organizations/{organization_id} endpoint."""
@pytest.mark.asyncio
async def test_update_organization_as_admin_success(
self,
client,
async_test_user,
test_org_with_user_admin
self, client, async_test_user, test_org_with_user_admin
):
"""Test successfully updating organization as admin (covers lines 193-215)."""
# Login as admin user
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
admin_token = login_response.json()["access_token"]
response = await client.put(
f"/api/v1/organizations/{test_org_with_user_admin.id}",
json={
"name": "Updated Admin Org",
"description": "Updated description"
},
headers={"Authorization": f"Bearer {admin_token}"}
json={"name": "Updated Admin Org", "description": "Updated description"},
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -450,23 +424,20 @@ class TestUpdateOrganization:
@pytest.mark.asyncio
async def test_update_organization_as_owner_success(
self,
client,
async_test_user,
test_org_with_user_owner
self, client, async_test_user, test_org_with_user_owner
):
"""Test successfully updating organization as owner."""
# Login as owner user
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
owner_token = login_response.json()["access_token"]
response = await client.put(
f"/api/v1/organizations/{test_org_with_user_owner.id}",
json={"name": "Updated Owner Org"},
headers={"Authorization": f"Bearer {owner_token}"}
headers={"Authorization": f"Bearer {owner_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -475,16 +446,13 @@ class TestUpdateOrganization:
@pytest.mark.asyncio
async def test_update_organization_as_member_fails(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test updating organization as regular member fails."""
response = await client.put(
f"/api/v1/organizations/{test_org_with_user_member.id}",
json={"name": "Should Fail"},
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Should fail permission check (need admin or owner)
@@ -492,15 +460,13 @@ class TestUpdateOrganization:
@pytest.mark.asyncio
async def test_update_organization_not_found(
self,
client,
test_org_with_user_admin
self, client, test_org_with_user_admin
):
"""Test updating nonexistent organization returns 403 (permission check first)."""
# Login as admin
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
admin_token = login_response.json()["access_token"]
@@ -508,7 +474,7 @@ class TestUpdateOrganization:
response = await client.put(
f"/api/v1/organizations/{fake_org_id}",
json={"name": "Updated"},
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}"},
)
# Permission dependency checks admin role before endpoint logic
@@ -520,6 +486,7 @@ class TestUpdateOrganization:
# ===== Authentication Tests =====
class TestOrganizationAuthentication:
"""Test authentication requirements for organization endpoints."""
@@ -548,14 +515,14 @@ class TestOrganizationAuthentication:
"""Test unauthenticated access to update fails."""
fake_id = uuid4()
response = await client.put(
f"/api/v1/organizations/{fake_id}",
json={"name": "Test"}
f"/api/v1/organizations/{fake_id}", json={"name": "Test"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
# ===== Exception Handler Tests (Database Error Scenarios) =====
class TestOrganizationExceptionHandlers:
"""
Test exception handlers in organization endpoints.
@@ -566,86 +533,74 @@ class TestOrganizationExceptionHandlers:
@pytest.mark.asyncio
async def test_get_my_organizations_database_error(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test generic exception handler in get_my_organizations (covers lines 81-83)."""
with patch(
"app.crud.organization.organization.get_user_organizations_with_details",
side_effect=Exception("Database connection lost")
side_effect=Exception("Database connection lost"),
):
# The exception handler logs and re-raises, so we expect the exception
# to propagate (which proves the handler executed)
with pytest.raises(Exception, match="Database connection lost"):
await client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
@pytest.mark.asyncio
async def test_get_organization_database_error(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test generic exception handler in get_organization (covers lines 124-128)."""
with patch(
"app.crud.organization.organization.get",
side_effect=Exception("Database timeout")
side_effect=Exception("Database timeout"),
):
with pytest.raises(Exception, match="Database timeout"):
await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
@pytest.mark.asyncio
async def test_get_organization_members_database_error(
self,
client,
user_token,
test_org_with_user_member
self, client, user_token, test_org_with_user_member
):
"""Test generic exception handler in get_organization_members (covers lines 170-172)."""
with patch(
"app.crud.organization.organization.get_organization_members",
side_effect=Exception("Connection pool exhausted")
side_effect=Exception("Connection pool exhausted"),
):
with pytest.raises(Exception, match="Connection pool exhausted"):
await client.get(
f"/api/v1/organizations/{test_org_with_user_member.id}/members",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
@pytest.mark.asyncio
async def test_update_organization_database_error(
self,
client,
async_test_user,
test_org_with_user_admin
self, client, async_test_user, test_org_with_user_admin
):
"""Test generic exception handler in update_organization (covers lines 217-221)."""
# Login as admin user
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "testuser@example.com", "password": "TestPassword123!"}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
admin_token = login_response.json()["access_token"]
with patch(
"app.crud.organization.organization.get",
return_value=test_org_with_user_admin
return_value=test_org_with_user_admin,
):
with patch(
"app.crud.organization.organization.update",
side_effect=Exception("Write lock timeout")
side_effect=Exception("Write lock timeout"),
):
with pytest.raises(Exception, match="Write lock timeout"):
await client.put(
f"/api/v1/organizations/{test_org_with_user_admin.id}",
json={"name": "Should Fail"},
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}"},
)

View File

@@ -5,15 +5,17 @@ Tests for permission dependencies - CRITICAL SECURITY PATHS.
These tests ensure superusers can bypass organization checks correctly,
and that regular users are properly blocked.
"""
from uuid import uuid4
import pytest
import pytest_asyncio
from fastapi import status
from uuid import uuid4
from app.core.auth import get_password_hash
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.core.auth import get_password_hash
from app.models.user_organization import OrganizationRole, UserOrganization
@pytest_asyncio.fixture
@@ -21,10 +23,7 @@ async def superuser_token(client, async_test_superuser):
"""Get access token for superuser."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "superuser@example.com",
"password": "SuperPassword123!"
}
json={"email": "superuser@example.com", "password": "SuperPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -35,10 +34,7 @@ async def regular_user_token(client, async_test_user):
"""Get access token for regular user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -47,12 +43,12 @@ async def regular_user_token(client, async_test_user):
@pytest_asyncio.fixture
async def test_org_no_members(async_test_db):
"""Create a test organization with NO members."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="No Members Org",
slug="no-members-org",
description="Test org with no members"
description="Test org with no members",
)
session.add(org)
await session.commit()
@@ -63,12 +59,12 @@ async def test_org_no_members(async_test_db):
@pytest_asyncio.fixture
async def test_org_with_member(async_test_db, async_test_user):
"""Create a test organization with async_test_user as member (not admin)."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Member Only Org",
slug="member-only-org",
description="Test org where user is just a member"
description="Test org where user is just a member",
)
session.add(org)
await session.commit()
@@ -79,7 +75,7 @@ async def test_org_with_member(async_test_db, async_test_user):
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.MEMBER,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -89,6 +85,7 @@ async def test_org_with_member(async_test_db, async_test_user):
# ===== CRITICAL SECURITY TESTS: Superuser Bypass =====
class TestSuperuserBypass:
"""
CRITICAL: Test that superusers can bypass organization checks.
@@ -99,10 +96,7 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_superuser_can_access_org_not_member_of(
self,
client,
superuser_token,
test_org_no_members
self, client, superuser_token, test_org_no_members
):
"""
CRITICAL: Superuser should bypass membership check (covers line 175).
@@ -111,7 +105,7 @@ class TestSuperuserBypass:
"""
response = await client.get(
f"/api/v1/organizations/{test_org_no_members.id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Superuser should succeed even though they're not a member
@@ -121,15 +115,12 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_regular_user_cannot_access_org_not_member_of(
self,
client,
regular_user_token,
test_org_no_members
self, client, regular_user_token, test_org_no_members
):
"""Regular user should be blocked from org they're not a member of."""
response = await client.get(
f"/api/v1/organizations/{test_org_no_members.id}",
headers={"Authorization": f"Bearer {regular_user_token}"}
headers={"Authorization": f"Bearer {regular_user_token}"},
)
# Regular user should fail permission check
@@ -137,10 +128,7 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_superuser_can_update_org_not_admin_of(
self,
client,
superuser_token,
test_org_no_members
self, client, superuser_token, test_org_no_members
):
"""
CRITICAL: Superuser should bypass admin check (covers line 99).
@@ -150,7 +138,7 @@ class TestSuperuserBypass:
response = await client.put(
f"/api/v1/organizations/{test_org_no_members.id}",
json={"name": "Updated by Superuser"},
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Superuser should succeed in updating org
@@ -160,16 +148,13 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_regular_member_cannot_update_org(
self,
client,
regular_user_token,
test_org_with_member
self, client, regular_user_token, test_org_with_member
):
"""Regular member (not admin) should NOT be able to update org."""
response = await client.put(
f"/api/v1/organizations/{test_org_with_member.id}",
json={"name": "Should Fail"},
headers={"Authorization": f"Bearer {regular_user_token}"}
headers={"Authorization": f"Bearer {regular_user_token}"},
)
# Member should fail - need admin or owner role
@@ -177,15 +162,12 @@ class TestSuperuserBypass:
@pytest.mark.asyncio
async def test_superuser_can_list_org_members_not_member_of(
self,
client,
superuser_token,
test_org_no_members
self, client, superuser_token, test_org_no_members
):
"""CRITICAL: Superuser should bypass membership check to list members."""
response = await client.get(
f"/api/v1/organizations/{test_org_no_members.id}/members",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Superuser should succeed
@@ -197,13 +179,14 @@ class TestSuperuserBypass:
# ===== Edge Cases and Security Tests =====
class TestPermissionEdgeCases:
"""Test edge cases in permission system."""
@pytest.mark.asyncio
async def test_inactive_user_blocked(self, client, async_test_db):
"""Test that inactive users are blocked."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create inactive user
async with AsyncTestingSessionLocal() as session:
@@ -213,7 +196,7 @@ class TestPermissionEdgeCases:
password_hash=get_password_hash("TestPassword123!"),
first_name="Inactive",
last_name="User",
is_active=False # INACTIVE
is_active=False, # INACTIVE
)
session.add(user)
await session.commit()
@@ -222,7 +205,7 @@ class TestPermissionEdgeCases:
# But accessing protected endpoints should fail
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "inactive@example.com", "password": "TestPassword123!"}
json={"email": "inactive@example.com", "password": "TestPassword123!"},
)
# Login might fail for inactive users depending on auth implementation
@@ -231,18 +214,18 @@ class TestPermissionEdgeCases:
# Try to access protected endpoint
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {token}"}
"/api/v1/users/me", headers={"Authorization": f"Bearer {token}"}
)
# Should be blocked
assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
assert response.status_code in [
status.HTTP_401_UNAUTHORIZED,
status.HTTP_403_FORBIDDEN,
]
@pytest.mark.asyncio
async def test_nonexistent_organization_returns_403_not_404(
self,
client,
regular_user_token
self, client, regular_user_token
):
"""
Test that accessing nonexistent org returns 403, not 404.
@@ -254,7 +237,7 @@ class TestPermissionEdgeCases:
fake_org_id = uuid4()
response = await client.get(
f"/api/v1/organizations/{fake_org_id}",
headers={"Authorization": f"Bearer {regular_user_token}"}
headers={"Authorization": f"Bearer {regular_user_token}"},
)
# Should get 403 (not a member), not 404 (doesn't exist)
@@ -264,18 +247,16 @@ class TestPermissionEdgeCases:
# ===== Admin Role Tests =====
class TestAdminRolePermissions:
"""Test admin role can perform admin actions."""
@pytest_asyncio.fixture
async def test_org_with_admin(self, async_test_db, async_test_user):
"""Create org where user is ADMIN."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
org = Organization(
name="Admin Org",
slug="admin-org"
)
org = Organization(name="Admin Org", slug="admin-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -284,7 +265,7 @@ class TestAdminRolePermissions:
user_id=async_test_user.id,
organization_id=org.id,
role=OrganizationRole.ADMIN,
is_active=True
is_active=True,
)
session.add(membership)
await session.commit()
@@ -293,16 +274,13 @@ class TestAdminRolePermissions:
@pytest.mark.asyncio
async def test_admin_can_update_org(
self,
client,
regular_user_token,
test_org_with_admin
self, client, regular_user_token, test_org_with_admin
):
"""Admin should be able to update organization."""
response = await client.put(
f"/api/v1/organizations/{test_org_with_admin.id}",
json={"name": "Updated by Admin"},
headers={"Authorization": f"Bearer {regular_user_token}"}
headers={"Authorization": f"Bearer {regular_user_token}"},
)
assert response.status_code == status.HTTP_200_OK

View File

@@ -7,13 +7,13 @@ Critical security tests covering:
These tests prevent unauthorized access and privilege escalation.
"""
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.user import User
from app.models.organization import Organization
from app.crud.user import user as user_crud
from app.models.organization import Organization
from app.models.user import User
class TestInactiveUserBlocking:
@@ -29,11 +29,7 @@ class TestInactiveUserBlocking:
@pytest.mark.asyncio
async def test_inactive_user_cannot_access_protected_endpoints(
self,
client: AsyncClient,
async_test_db,
async_test_user: User,
user_token: str
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
):
"""
Test that inactive users are blocked from protected endpoints.
@@ -44,12 +40,11 @@ class TestInactiveUserBlocking:
3. User tries to access protected endpoint with valid token
4. System MUST reject (account inactive)
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Step 1: Verify user can access endpoint while active
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == 200, "Active user should have access"
@@ -61,8 +56,7 @@ class TestInactiveUserBlocking:
# Step 3: User tries to access endpoint with same token
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
)
# Step 4: System MUST reject (covers lines 52-57)
@@ -75,18 +69,14 @@ class TestInactiveUserBlocking:
@pytest.mark.asyncio
async def test_inactive_user_blocked_from_organization_endpoints(
self,
client: AsyncClient,
async_test_db,
async_test_user: User,
user_token: str
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
):
"""
Test that inactive users can't access organization endpoints.
Ensures the inactive check applies to ALL protected endpoints.
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Deactivate user
async with SessionLocal() as session:
@@ -97,7 +87,7 @@ class TestInactiveUserBlocking:
# Try to list organizations
response = await client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Must be blocked
@@ -122,7 +112,7 @@ class TestSuperuserPrivilegeEscalation:
client: AsyncClient,
async_test_db,
async_test_superuser: User,
superuser_token: str
superuser_token: str,
):
"""
Test that superusers automatically get OWNER role in organizations.
@@ -131,14 +121,11 @@ class TestSuperuserPrivilegeEscalation:
Superusers can manage any organization without being explicitly added.
This is for platform administration.
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Step 1: Create an organization (owned by someone else)
async with SessionLocal() as session:
org = Organization(
name="Test Organization",
slug="test-org"
)
org = Organization(name="Test Organization", slug="test-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -148,7 +135,7 @@ class TestSuperuserPrivilegeEscalation:
# (They're not a member, but should auto-get OWNER role)
response = await client.get(
f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Step 3: Should have access (covers lines 154-157)
@@ -161,21 +148,18 @@ class TestSuperuserPrivilegeEscalation:
client: AsyncClient,
async_test_db,
async_test_superuser: User,
superuser_token: str
superuser_token: str,
):
"""
Test that superusers have full management access to all organizations.
Ensures the OWNER role privilege escalation works end-to-end.
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization
async with SessionLocal() as session:
org = Organization(
name="Test Organization",
slug="test-org"
)
org = Organization(name="Test Organization", slug="test-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -185,34 +169,29 @@ class TestSuperuserPrivilegeEscalation:
response = await client.put(
f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"name": "Updated Name"}
json={"name": "Updated Name"},
)
# Should succeed (superuser has OWNER privileges)
assert response.status_code in [200, 404], "Superuser should be able to manage any org"
assert response.status_code in [200, 404], (
"Superuser should be able to manage any org"
)
# Note: Might be 404 if org endpoints require membership, but the role check passes
@pytest.mark.asyncio
async def test_regular_user_does_not_get_owner_role(
self,
client: AsyncClient,
async_test_db,
async_test_user: User,
user_token: str
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
):
"""
Sanity check: Regular users don't get automatic OWNER role.
Ensures the superuser check is working correctly (line 154).
"""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization
async with SessionLocal() as session:
org = Organization(
name="Test Organization",
slug="test-org"
)
org = Organization(name="Test Organization", slug="test-org")
session.add(org)
await session.commit()
await session.refresh(org)
@@ -221,8 +200,10 @@ class TestSuperuserPrivilegeEscalation:
# Regular user tries to access it (not a member)
response = await client.get(
f"/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Should be denied (not a member, not a superuser)
assert response.status_code in [403, 404], "Regular user shouldn't access non-member org"
assert response.status_code in [403, 404], (
"Regular user shouldn't access non-member org"
)

View File

@@ -1,7 +1,8 @@
# tests/api/test_security_headers.py
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch
from app.main import app
@@ -11,8 +12,10 @@ def client():
"""Create a FastAPI test client for the main app (module-scoped for speed)."""
# Mock get_db to avoid database connection issues
with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None)
@@ -77,8 +80,10 @@ class TestSecurityHeaders:
"""Test that HSTS header is set in production (covers line 95)"""
with patch("app.core.config.settings.ENVIRONMENT", "production"):
with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None)
@@ -88,20 +93,26 @@ class TestSecurityHeaders:
# Need to reimport app to pick up the new settings
from importlib import reload
import app.main
reload(app.main)
test_client = TestClient(app.main.app)
response = test_client.get("/health")
assert "Strict-Transport-Security" in response.headers
assert "max-age=31536000" in response.headers["Strict-Transport-Security"]
assert (
"max-age=31536000" in response.headers["Strict-Transport-Security"]
)
def test_csp_strict_mode(self):
"""Test CSP strict mode (covers line 121)"""
with patch("app.core.config.settings.CSP_MODE", "strict"):
with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None)
@@ -110,7 +121,9 @@ class TestSecurityHeaders:
mock_get_db.side_effect = lambda: mock_session_generator()
from importlib import reload
import app.main
reload(app.main)
test_client = TestClient(app.main.app)
@@ -136,8 +149,10 @@ class TestRootEndpoint:
def test_root_endpoint(self):
"""Test root endpoint returns HTML (covers line 174)"""
with patch("app.core.database.get_db") as mock_get_db:
async def mock_session_generator():
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock
mock_session = MagicMock()
mock_session.execute = AsyncMock(return_value=None)
mock_session.close = AsyncMock(return_value=None)

View File

@@ -2,23 +2,23 @@
"""
Comprehensive tests for session management API endpoints.
"""
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
from uuid import uuid4
import pytest
import pytest_asyncio
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from unittest.mock import patch
from fastapi import status
from app.models.user_session import UserSession
from app.schemas.users import UserCreate
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.sessions.limiter.enabled', False):
with patch("app.api.routes.sessions.limiter.enabled", False):
yield
@@ -27,10 +27,7 @@ async def user_token(client, async_test_user):
"""Create and return an access token for async_test_user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -39,7 +36,7 @@ async def user_token(client, async_test_user):
@pytest_asyncio.fixture
async def async_test_user2(async_test_db):
"""Create a second test user."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
from app.crud.user import user as user_crud
@@ -49,7 +46,7 @@ async def async_test_user2(async_test_db):
email="testuser2@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User2"
last_name="User2",
)
user = await user_crud.create(session, obj_in=user_data)
await session.commit()
@@ -61,9 +58,11 @@ class TestListMySessions:
"""Tests for GET /api/v1/sessions/me endpoint."""
@pytest.mark.asyncio
async def test_list_my_sessions_success(self, client, async_test_user, async_test_db, user_token):
async def test_list_my_sessions_success(
self, client, async_test_user, async_test_db, user_token
):
"""Test successfully listing user's active sessions."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create some sessions for the user
async with SessionLocal() as session:
@@ -75,8 +74,8 @@ class TestListMySessions:
ip_address="192.168.1.100",
user_agent="Mozilla/5.0 (iPhone)",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
# Active session 2
s2 = UserSession(
@@ -86,8 +85,8 @@ class TestListMySessions:
ip_address="192.168.1.101",
user_agent="Mozilla/5.0 (Macintosh)",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
# Inactive session (should not appear)
s3 = UserSession(
@@ -97,16 +96,15 @@ class TestListMySessions:
ip_address="192.168.1.102",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(days=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(days=1),
)
session.add_all([s1, s2, s3])
await session.commit()
# Make request
response = await client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -128,11 +126,12 @@ class TestListMySessions:
assert data["sessions"][0]["is_current"] is True
@pytest.mark.asyncio
async def test_list_my_sessions_with_login_session(self, client, async_test_user, user_token):
async def test_list_my_sessions_with_login_session(
self, client, async_test_user, user_token
):
"""Test listing sessions shows the login session."""
response = await client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -155,9 +154,11 @@ class TestRevokeSession:
"""Tests for DELETE /api/v1/sessions/{session_id} endpoint."""
@pytest.mark.asyncio
async def test_revoke_session_success(self, client, async_test_user, async_test_db, user_token):
async def test_revoke_session_success(
self, client, async_test_user, async_test_db, user_token
):
"""Test successfully revoking a session."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session to revoke
async with SessionLocal() as session:
@@ -168,8 +169,8 @@ class TestRevokeSession:
ip_address="192.168.1.103",
user_agent="Mozilla/5.0 (iPad)",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -179,7 +180,7 @@ class TestRevokeSession:
# Revoke the session
response = await client.delete(
f"/api/v1/sessions/{session_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -191,6 +192,7 @@ class TestRevokeSession:
# Verify session is deactivated
async with SessionLocal() as session:
from app.crud.session import session as session_crud
revoked_session = await session_crud.get(session, id=str(session_id))
assert revoked_session.is_active is False
@@ -200,7 +202,7 @@ class TestRevokeSession:
fake_id = uuid4()
response = await client.delete(
f"/api/v1/sessions/{fake_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -222,7 +224,7 @@ class TestRevokeSession:
self, client, async_test_user, async_test_user2, async_test_db, user_token
):
"""Test that users cannot revoke other users' sessions."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session for user2
async with SessionLocal() as session:
@@ -233,8 +235,8 @@ class TestRevokeSession:
ip_address="192.168.1.200",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(other_user_session)
await session.commit()
@@ -244,7 +246,7 @@ class TestRevokeSession:
# Try to revoke it as user1
response = await client.delete(
f"/api/v1/sessions/{session_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -263,7 +265,7 @@ class TestCleanupExpiredSessions:
self, client, async_test_user, async_test_db, user_token
):
"""Test successfully cleaning up expired sessions."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create expired and active sessions using CRUD to avoid greenlet issues
from app.crud.session import session as session_crud
@@ -277,8 +279,8 @@ class TestCleanupExpiredSessions:
device_name="Expired 1",
ip_address="192.168.1.201",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
e1 = await session_crud.create_session(db, obj_in=e1_data)
e1.is_active = False
@@ -291,8 +293,8 @@ class TestCleanupExpiredSessions:
device_name="Expired 2",
ip_address="192.168.1.202",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2),
)
e2 = await session_crud.create_session(db, obj_in=e2_data)
e2.is_active = False
@@ -305,8 +307,8 @@ class TestCleanupExpiredSessions:
device_name="Active",
ip_address="192.168.1.203",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
await session_crud.create_session(db, obj_in=a1_data)
await db.commit()
@@ -314,7 +316,7 @@ class TestCleanupExpiredSessions:
# Cleanup expired sessions
response = await client.delete(
"/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -329,7 +331,7 @@ class TestCleanupExpiredSessions:
self, client, async_test_user, async_test_db, user_token
):
"""Test cleanup when no sessions are expired."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create only active sessions using CRUD
from app.crud.session import session as session_crud
@@ -342,15 +344,15 @@ class TestCleanupExpiredSessions:
device_name="Active Device",
ip_address="192.168.1.210",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
await session_crud.create_session(db, obj_in=a1_data)
await db.commit()
response = await client.delete(
"/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -369,13 +371,16 @@ class TestCleanupExpiredSessions:
# Additional tests for better coverage
class TestSessionsAdditionalCases:
"""Additional tests to improve sessions endpoint coverage."""
@pytest.mark.asyncio
async def test_list_sessions_pagination(self, client, async_test_user, async_test_db, user_token):
async def test_list_sessions_pagination(
self, client, async_test_user, async_test_db, user_token
):
"""Test listing sessions with pagination."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create multiple sessions
async with SessionLocal() as session:
@@ -389,15 +394,15 @@ class TestSessionsAdditionalCases:
device_name=f"Device {i}",
ip_address=f"192.168.1.{i}",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
await session_crud.create_session(session, obj_in=session_data)
await session.commit()
response = await client.get(
"/api/v1/sessions/me?page=1&limit=3",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -410,16 +415,21 @@ class TestSessionsAdditionalCases:
"""Test revoking session with invalid UUID."""
response = await client.delete(
"/api/v1/sessions/not-a-uuid",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
# Should return 422 for invalid UUID format
assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_404_NOT_FOUND]
assert response.status_code in [
status.HTTP_422_UNPROCESSABLE_ENTITY,
status.HTTP_404_NOT_FOUND,
]
@pytest.mark.asyncio
async def test_cleanup_expired_sessions_with_mixed_states(self, client, async_test_user, async_test_db, user_token):
async def test_cleanup_expired_sessions_with_mixed_states(
self, client, async_test_user, async_test_db, user_token
):
"""Test cleanup with mix of active/inactive and expired/not-expired sessions."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
from app.crud.session import session as session_crud
from app.schemas.sessions import SessionCreate
@@ -432,8 +442,8 @@ class TestSessionsAdditionalCases:
device_name="Expired Inactive",
ip_address="192.168.1.100",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
e1 = await session_crud.create_session(db, obj_in=e1_data)
e1.is_active = False
@@ -446,8 +456,8 @@ class TestSessionsAdditionalCases:
device_name="Expired Active",
ip_address="192.168.1.101",
user_agent="Mozilla/5.0",
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2),
)
await session_crud.create_session(db, obj_in=e2_data)
@@ -455,7 +465,7 @@ class TestSessionsAdditionalCases:
response = await client.delete(
"/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -476,10 +486,12 @@ class TestSessionExceptionHandlers:
from unittest.mock import patch
# Patch decode_token to raise an exception
with patch('app.api.routes.sessions.decode_token', side_effect=Exception("Token decode error")):
with patch(
"app.api.routes.sessions.decode_token",
side_effect=Exception("Token decode error"),
):
response = await client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
)
# Should still succeed (exception is caught and ignored in try/except at line 77)
@@ -489,12 +501,16 @@ class TestSessionExceptionHandlers:
async def test_list_sessions_database_error(self, client, user_token):
"""Test list_sessions handles database errors (covers lines 104-106)."""
from unittest.mock import patch
from app.crud import session as session_module
with patch.object(session_module.session, 'get_user_sessions', side_effect=Exception("Database error")):
with patch.object(
session_module.session,
"get_user_sessions",
side_effect=Exception("Database error"),
):
response = await client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -503,18 +519,21 @@ class TestSessionExceptionHandlers:
assert data["errors"][0]["message"] == "Failed to retrieve sessions"
@pytest.mark.asyncio
async def test_revoke_session_database_error(self, client, user_token, async_test_db, async_test_user):
async def test_revoke_session_database_error(
self, client, user_token, async_test_db, async_test_user
):
"""Test revoke_session handles database errors (covers lines 181-183)."""
from datetime import datetime, timedelta
from unittest.mock import patch
from uuid import uuid4
from app.crud import session as session_module
# First create a session to revoke
from app.crud.session import session as session_crud
from app.schemas.sessions import SessionCreate
from datetime import datetime, timedelta, timezone
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as db:
session_in = SessionCreate(
@@ -523,17 +542,21 @@ class TestSessionExceptionHandlers:
device_name="Test Device",
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60),
)
user_session = await session_crud.create_session(db, obj_in=session_in)
session_id = user_session.id
# Mock the deactivate method to raise an exception
with patch.object(session_module.session, 'deactivate', side_effect=Exception("Database connection lost")):
with patch.object(
session_module.session,
"deactivate",
side_effect=Exception("Database connection lost"),
):
response = await client.delete(
f"/api/v1/sessions/{session_id}",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -544,12 +567,17 @@ class TestSessionExceptionHandlers:
async def test_cleanup_expired_sessions_database_error(self, client, user_token):
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
from unittest.mock import patch
from app.crud import session as session_module
with patch.object(session_module.session, 'cleanup_expired_for_user', side_effect=Exception("Cleanup failed")):
with patch.object(
session_module.session,
"cleanup_expired_for_user",
side_effect=Exception("Cleanup failed"),
):
response = await client.delete(
"/api/v1/sessions/me/expired",
headers={"Authorization": f"Bearer {user_token}"}
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

View File

@@ -3,32 +3,29 @@
Comprehensive tests for user management endpoints.
These tests focus on finding potential bugs, not just coverage.
"""
import pytest
import pytest_asyncio
from unittest.mock import patch
from fastapi import status
import uuid
from sqlalchemy import select
import uuid
from unittest.mock import patch
import pytest
from fastapi import status
from app.models.user import User
from app.models.user import User
from app.schemas.users import UserUpdate
# Disable rate limiting for tests
@pytest.fixture(autouse=True)
def disable_rate_limit():
"""Disable rate limiting for all tests in this module."""
with patch('app.api.routes.users.limiter.enabled', False):
with patch('app.api.routes.auth.limiter.enabled', False):
with patch("app.api.routes.users.limiter.enabled", False):
with patch("app.api.routes.auth.limiter.enabled", False):
yield
async def get_auth_headers(client, email, password):
"""Helper to get authentication headers."""
response = await client.post(
"/api/v1/auth/login",
json={"email": email, "password": password}
"/api/v1/auth/login", json={"email": email, "password": password}
)
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
@@ -40,7 +37,9 @@ class TestListUsers:
@pytest.mark.asyncio
async def test_list_users_as_superuser(self, client, async_test_superuser):
"""Test listing users as superuser."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get("/api/v1/users", headers=headers)
@@ -53,16 +52,20 @@ class TestListUsers:
@pytest.mark.asyncio
async def test_list_users_as_regular_user(self, client, async_test_user):
"""Test that regular users cannot list users."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get("/api/v1/users", headers=headers)
assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_list_users_pagination(self, client, async_test_superuser, async_test_db):
async def test_list_users_pagination(
self, client, async_test_superuser, async_test_db
):
"""Test pagination works correctly."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -72,12 +75,14 @@ class TestListUsers:
password_hash="hash",
first_name=f"PagUser{i}",
is_active=True,
is_superuser=False
is_superuser=False,
)
session.add(user)
await session.commit()
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
# Get first page
response = await client.get("/api/v1/users?page=1&limit=5", headers=headers)
@@ -88,9 +93,11 @@ class TestListUsers:
assert data["pagination"]["total"] >= 15
@pytest.mark.asyncio
async def test_list_users_filter_active(self, client, async_test_superuser, async_test_db):
async def test_list_users_filter_active(
self, client, async_test_superuser, async_test_db
):
"""Test filtering by active status."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create active and inactive users
async with AsyncTestingSessionLocal() as session:
@@ -99,19 +106,21 @@ class TestListUsers:
password_hash="hash",
first_name="Active",
is_active=True,
is_superuser=False
is_superuser=False,
)
inactive_user = User(
email="inactivefilter@example.com",
password_hash="hash",
first_name="Inactive",
is_active=False,
is_superuser=False
is_superuser=False,
)
session.add_all([active_user, inactive_user])
await session.commit()
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
# Filter for active users
response = await client.get("/api/v1/users?is_active=true", headers=headers)
@@ -130,9 +139,13 @@ class TestListUsers:
@pytest.mark.asyncio
async def test_list_users_sort_by_email(self, client, async_test_superuser):
"""Test sorting users by email."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get("/api/v1/users?sort_by=email&sort_order=asc", headers=headers)
response = await client.get(
"/api/v1/users?sort_by=email&sort_order=asc", headers=headers
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
emails = [u["email"] for u in data["data"]]
@@ -154,7 +167,9 @@ class TestGetCurrentUserProfile:
@pytest.mark.asyncio
async def test_get_own_profile(self, client, async_test_user):
"""Test getting own profile."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get("/api/v1/users/me", headers=headers)
@@ -176,12 +191,14 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_own_profile(self, client, async_test_user):
"""Test updating own profile."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me",
headers=headers,
json={"first_name": "Updated", "last_name": "Name"}
json={"first_name": "Updated", "last_name": "Name"},
)
assert response.status_code == status.HTTP_200_OK
@@ -192,12 +209,12 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_profile_phone_number(self, client, async_test_user, test_db):
"""Test updating phone number with validation."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me",
headers=headers,
json={"phone_number": "+19876543210"}
"/api/v1/users/me", headers=headers, json={"phone_number": "+19876543210"}
)
assert response.status_code == status.HTTP_200_OK
@@ -207,12 +224,12 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_profile_invalid_phone(self, client, async_test_user):
"""Test that invalid phone numbers are rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me",
headers=headers,
json={"phone_number": "invalid"}
"/api/v1/users/me", headers=headers, json={"phone_number": "invalid"}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -220,14 +237,16 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_cannot_elevate_to_superuser(self, client, async_test_user):
"""Test that users cannot make themselves superuser."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
# Note: is_superuser is now in UserUpdate schema with explicit validation
# This tests that Pydantic rejects the attempt at the schema level
response = await client.patch(
"/api/v1/users/me",
headers=headers,
json={"first_name": "Test", "is_superuser": True}
json={"first_name": "Test", "is_superuser": True},
)
# Pydantic validation should reject this at the schema level
@@ -242,10 +261,7 @@ class TestUpdateCurrentUser:
@pytest.mark.asyncio
async def test_update_profile_no_auth(self, client):
"""Test that unauthenticated requests are rejected."""
response = await client.patch(
"/api/v1/users/me",
json={"first_name": "Hacker"}
)
response = await client.patch("/api/v1/users/me", json={"first_name": "Hacker"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED
# Note: Removed test_update_profile_unexpected_error - see comment above
@@ -257,16 +273,22 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_own_profile_by_id(self, client, async_test_user):
"""Test getting own profile by ID."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
response = await client.get(
f"/api/v1/users/{async_test_user.id}", headers=headers
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["email"] == async_test_user.email
@pytest.mark.asyncio
async def test_get_other_user_as_regular_user(self, client, async_test_user, test_db):
async def test_get_other_user_as_regular_user(
self, client, async_test_user, test_db
):
"""Test that regular users cannot view other profiles."""
# Create another user
other_user = User(
@@ -274,24 +296,32 @@ class TestGetUserById:
password_hash="hash",
first_name="Other",
is_active=True,
is_superuser=False
is_superuser=False,
)
test_db.add(other_user)
test_db.commit()
test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.get(f"/api/v1/users/{other_user.id}", headers=headers)
assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_get_other_user_as_superuser(self, client, async_test_superuser, async_test_user):
async def test_get_other_user_as_superuser(
self, client, async_test_superuser, async_test_user
):
"""Test that superusers can view other profiles."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
response = await client.get(
f"/api/v1/users/{async_test_user.id}", headers=headers
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
@@ -300,7 +330,9 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_nonexistent_user(self, client, async_test_superuser):
"""Test getting non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
fake_id = uuid.uuid4()
response = await client.get(f"/api/v1/users/{fake_id}", headers=headers)
@@ -310,7 +342,9 @@ class TestGetUserById:
@pytest.mark.asyncio
async def test_get_user_invalid_uuid(self, client, async_test_superuser):
"""Test getting user with invalid UUID format."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.get("/api/v1/users/not-a-uuid", headers=headers)
@@ -323,12 +357,14 @@ class TestUpdateUserById:
@pytest.mark.asyncio
async def test_update_own_profile_by_id(self, client, async_test_user, test_db):
"""Test updating own profile by ID."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers=headers,
json={"first_name": "SelfUpdated"}
json={"first_name": "SelfUpdated"},
)
assert response.status_code == status.HTTP_200_OK
@@ -336,7 +372,9 @@ class TestUpdateUserById:
assert data["first_name"] == "SelfUpdated"
@pytest.mark.asyncio
async def test_update_other_user_as_regular_user(self, client, async_test_user, test_db):
async def test_update_other_user_as_regular_user(
self, client, async_test_user, test_db
):
"""Test that regular users cannot update other profiles."""
# Create another user
other_user = User(
@@ -344,18 +382,20 @@ class TestUpdateUserById:
password_hash="hash",
first_name="Other",
is_active=True,
is_superuser=False
is_superuser=False,
)
test_db.add(other_user)
test_db.commit()
test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
f"/api/v1/users/{other_user.id}",
headers=headers,
json={"first_name": "Hacked"}
json={"first_name": "Hacked"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -365,14 +405,18 @@ class TestUpdateUserById:
assert other_user.first_name == "Other"
@pytest.mark.asyncio
async def test_update_other_user_as_superuser(self, client, async_test_superuser, async_test_user, test_db):
async def test_update_other_user_as_superuser(
self, client, async_test_superuser, async_test_user, test_db
):
"""Test that superusers can update other profiles."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers=headers,
json={"first_name": "AdminUpdated"}
json={"first_name": "AdminUpdated"},
)
assert response.status_code == status.HTTP_200_OK
@@ -380,16 +424,20 @@ class TestUpdateUserById:
assert data["first_name"] == "AdminUpdated"
@pytest.mark.asyncio
async def test_regular_user_cannot_modify_superuser_status(self, client, async_test_user):
async def test_regular_user_cannot_modify_superuser_status(
self, client, async_test_user
):
"""Test that regular users cannot change superuser status even if they try."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
# is_superuser not in UserUpdate schema, so it gets ignored by Pydantic
# Just verify the user stays the same
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers=headers,
json={"first_name": "Test"}
json={"first_name": "Test"},
)
assert response.status_code == status.HTTP_200_OK
@@ -397,14 +445,18 @@ class TestUpdateUserById:
assert data["is_superuser"] is False
@pytest.mark.asyncio
async def test_superuser_can_update_users(self, client, async_test_superuser, async_test_user, test_db):
async def test_superuser_can_update_users(
self, client, async_test_superuser, async_test_user, test_db
):
"""Test that superusers can update other users."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers=headers,
json={"first_name": "AdminChanged", "is_active": False}
json={"first_name": "AdminChanged", "is_active": False},
)
assert response.status_code == status.HTTP_200_OK
@@ -415,13 +467,13 @@ class TestUpdateUserById:
@pytest.mark.asyncio
async def test_update_nonexistent_user(self, client, async_test_superuser):
"""Test updating non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
fake_id = uuid.uuid4()
response = await client.patch(
f"/api/v1/users/{fake_id}",
headers=headers,
json={"first_name": "Ghost"}
f"/api/v1/users/{fake_id}", headers=headers, json={"first_name": "Ghost"}
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -435,15 +487,17 @@ class TestChangePassword:
@pytest.mark.asyncio
async def test_change_password_success(self, client, async_test_user, test_db):
"""Test successful password change."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me/password",
headers=headers,
json={
"current_password": "TestPassword123!",
"new_password": "NewPassword123!"
}
"new_password": "NewPassword123!",
},
)
assert response.status_code == status.HTTP_200_OK
@@ -453,25 +507,24 @@ class TestChangePassword:
# Verify can login with new password
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": async_test_user.email,
"password": "NewPassword123!"
}
json={"email": async_test_user.email, "password": "NewPassword123!"},
)
assert login_response.status_code == status.HTTP_200_OK
@pytest.mark.asyncio
async def test_change_password_wrong_current(self, client, async_test_user):
"""Test that wrong current password is rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me/password",
headers=headers,
json={
"current_password": "WrongPassword123",
"new_password": "NewPassword123!"
}
"new_password": "NewPassword123!",
},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -479,15 +532,14 @@ class TestChangePassword:
@pytest.mark.asyncio
async def test_change_password_weak_new_password(self, client, async_test_user):
"""Test that weak new passwords are rejected."""
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.patch(
"/api/v1/users/me/password",
headers=headers,
json={
"current_password": "TestPassword123!",
"new_password": "weak"
}
json={"current_password": "TestPassword123!", "new_password": "weak"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -499,8 +551,8 @@ class TestChangePassword:
"/api/v1/users/me/password",
json={
"current_password": "TestPassword123!",
"new_password": "NewPassword123!"
}
"new_password": "NewPassword123!",
},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -511,9 +563,11 @@ class TestDeleteUser:
"""Tests for DELETE /users/{user_id} endpoint."""
@pytest.mark.asyncio
async def test_delete_user_as_superuser(self, client, async_test_superuser, async_test_db):
async def test_delete_user_as_superuser(
self, client, async_test_superuser, async_test_db
):
"""Test deleting a user as superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a user to delete
async with AsyncTestingSessionLocal() as session:
@@ -522,14 +576,16 @@ class TestDeleteUser:
password_hash="hash",
first_name="Delete",
is_active=True,
is_superuser=False
is_superuser=False,
)
session.add(user_to_delete)
await session.commit()
await session.refresh(user_to_delete)
user_id = user_to_delete.id
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.delete(f"/api/v1/users/{user_id}", headers=headers)
@@ -540,6 +596,7 @@ class TestDeleteUser:
# Verify user is soft-deleted (has deleted_at timestamp)
async with AsyncTestingSessionLocal() as session:
from sqlalchemy import select
result = await session.execute(select(User).where(User.id == user_id))
deleted_user = result.scalar_one_or_none()
assert deleted_user.deleted_at is not None
@@ -547,9 +604,13 @@ class TestDeleteUser:
@pytest.mark.asyncio
async def test_cannot_delete_self(self, client, async_test_superuser):
"""Test that users cannot delete their own account."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
response = await client.delete(f"/api/v1/users/{async_test_superuser.id}", headers=headers)
response = await client.delete(
f"/api/v1/users/{async_test_superuser.id}", headers=headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@@ -562,22 +623,28 @@ class TestDeleteUser:
password_hash="hash",
first_name="Protected",
is_active=True,
is_superuser=False
is_superuser=False,
)
test_db.add(other_user)
test_db.commit()
test_db.refresh(other_user)
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
headers = await get_auth_headers(
client, async_test_user.email, "TestPassword123!"
)
response = await client.delete(f"/api/v1/users/{other_user.id}", headers=headers)
response = await client.delete(
f"/api/v1/users/{other_user.id}", headers=headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_delete_nonexistent_user(self, client, async_test_superuser):
"""Test deleting non-existent user."""
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
headers = await get_auth_headers(
client, async_test_superuser.email, "SuperPassword123!"
)
fake_id = uuid.uuid4()
response = await client.delete(f"/api/v1/users/{fake_id}", headers=headers)

View File

@@ -2,10 +2,12 @@
"""
Tests for user routes.
"""
from uuid import uuid4
import pytest
import pytest_asyncio
from fastapi import status
from uuid import uuid4
@pytest_asyncio.fixture
@@ -13,10 +15,7 @@ async def superuser_token(client, async_test_superuser):
"""Get access token for superuser."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "superuser@example.com",
"password": "SuperPassword123!"
}
json={"email": "superuser@example.com", "password": "SuperPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -27,10 +26,7 @@ async def user_token(client, async_test_user):
"""Get access token for regular user."""
response = await client.post(
"/api/v1/auth/login",
json={
"email": "testuser@example.com",
"password": "TestPassword123!"
}
json={"email": "testuser@example.com", "password": "TestPassword123!"},
)
assert response.status_code == 200
return response.json()["access_token"]
@@ -43,8 +39,7 @@ class TestListUsers:
async def test_list_users_success(self, client, superuser_token):
"""Test listing users successfully (covers lines 87-100)."""
response = await client.get(
"/api/v1/users",
headers={"Authorization": f"Bearer {superuser_token}"}
"/api/v1/users", headers={"Authorization": f"Bearer {superuser_token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -58,7 +53,7 @@ class TestListUsers:
"""Test listing users with is_superuser filter (covers line 74)."""
response = await client.get(
"/api/v1/users?is_superuser=true",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -73,8 +68,7 @@ class TestGetCurrentUser:
async def test_get_current_user_success(self, client, async_test_user, user_token):
"""Test getting current user profile."""
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"}
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -92,7 +86,7 @@ class TestUpdateCurrentUser:
response = await client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"},
json={"first_name": "UpdatedName"}
json={"first_name": "UpdatedName"},
)
assert response.status_code == status.HTTP_200_OK
@@ -104,12 +98,14 @@ class TestUpdateCurrentUser:
"""Test database error handling during update (covers lines 162-169)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("DB error")):
with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("DB error")
):
with pytest.raises(Exception):
await client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
@pytest.mark.asyncio
@@ -118,7 +114,7 @@ class TestUpdateCurrentUser:
response = await client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"},
json={"is_superuser": True}
json={"is_superuser": True},
)
# Pydantic validation should reject this at the schema level
@@ -137,12 +133,15 @@ class TestUpdateCurrentUser:
"""Test ValueError handling during update (covers lines 165-166)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid value")):
with patch(
"app.api.routes.users.user_crud.update",
side_effect=ValueError("Invalid value"),
):
with pytest.raises(ValueError):
await client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {user_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
@@ -154,7 +153,7 @@ class TestGetUser:
"""Test getting user by ID."""
response = await client.get(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -167,7 +166,7 @@ class TestGetUser:
fake_id = uuid4()
response = await client.get(
f"/api/v1/users/{fake_id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -183,30 +182,34 @@ class TestUpdateUserById:
response = await client.patch(
f"/api/v1/users/{fake_id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.asyncio
async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(self, client, async_test_user, user_token):
async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(
self, client, async_test_user, user_token
):
"""Test non-superuser cannot modify superuser status (Pydantic validation)."""
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {user_token}"},
json={"is_superuser": True}
json={"is_superuser": True},
)
# Pydantic validation should reject this at the schema level
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_update_user_by_id_success(self, client, async_test_user, superuser_token):
async def test_update_user_by_id_success(
self, client, async_test_user, superuser_token
):
"""Test updating user successfully (covers lines 276-278)."""
response = await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "SuperUpdated"}
json={"first_name": "SuperUpdated"},
)
assert response.status_code == status.HTTP_200_OK
@@ -214,29 +217,37 @@ class TestUpdateUserById:
assert data["first_name"] == "SuperUpdated"
@pytest.mark.asyncio
async def test_update_user_by_id_value_error(self, client, async_test_user, superuser_token):
async def test_update_user_by_id_value_error(
self, client, async_test_user, superuser_token
):
"""Test ValueError handling (covers lines 280-281)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid")):
with patch(
"app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid")
):
with pytest.raises(ValueError):
await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
@pytest.mark.asyncio
async def test_update_user_by_id_unexpected_error(self, client, async_test_user, superuser_token):
async def test_update_user_by_id_unexpected_error(
self, client, async_test_user, superuser_token
):
"""Test unexpected error handling (covers lines 283-284)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("Unexpected")):
with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected")
):
with pytest.raises(Exception):
await client.patch(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"},
json={"first_name": "Updated"}
json={"first_name": "Updated"},
)
@@ -246,18 +257,18 @@ class TestChangePassword:
@pytest.mark.asyncio
async def test_change_password_success(self, client, async_test_db):
"""Test changing password successfully."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a fresh user
async with AsyncTestingSessionLocal() as session:
from app.models.user import User
from app.core.auth import get_password_hash
from app.models.user import User
new_user = User(
email="changepass@example.com",
password_hash=get_password_hash("OldPassword123!"),
first_name="Change",
last_name="Pass"
last_name="Pass",
)
session.add(new_user)
await session.commit()
@@ -265,10 +276,7 @@ class TestChangePassword:
# Login
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": "changepass@example.com",
"password": "OldPassword123!"
}
json={"email": "changepass@example.com", "password": "OldPassword123!"},
)
token = login_response.json()["access_token"]
@@ -278,8 +286,8 @@ class TestChangePassword:
headers={"Authorization": f"Bearer {token}"},
json={
"current_password": "OldPassword123!",
"new_password": "NewPassword456!"
}
"new_password": "NewPassword456!",
},
)
assert response.status_code == status.HTTP_200_OK
@@ -289,10 +297,7 @@ class TestChangePassword:
# Verify new password works
login_response = await client.post(
"/api/v1/auth/login",
json={
"email": "changepass@example.com",
"password": "NewPassword456!"
}
json={"email": "changepass@example.com", "password": "NewPassword456!"},
)
assert login_response.status_code == status.HTTP_200_OK
@@ -306,7 +311,7 @@ class TestDeleteUserById:
fake_id = uuid4()
response = await client.delete(
f"/api/v1/users/{fake_id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -314,18 +319,18 @@ class TestDeleteUserById:
@pytest.mark.asyncio
async def test_delete_user_success(self, client, async_test_db, superuser_token):
"""Test deleting user successfully (covers lines 383-388)."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a user to delete
async with AsyncTestingSessionLocal() as session:
from app.models.user import User
from app.core.auth import get_password_hash
from app.models.user import User
user_to_delete = User(
email=f"delete{uuid4().hex[:8]}@example.com",
password_hash=get_password_hash("Password123!"),
first_name="Delete",
last_name="Me"
last_name="Me",
)
session.add(user_to_delete)
await session.commit()
@@ -334,7 +339,7 @@ class TestDeleteUserById:
response = await client.delete(
f"/api/v1/users/{user_id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -342,25 +347,35 @@ class TestDeleteUserById:
assert data["success"] is True
@pytest.mark.asyncio
async def test_delete_user_value_error(self, client, async_test_user, superuser_token):
async def test_delete_user_value_error(
self, client, async_test_user, superuser_token
):
"""Test ValueError handling during delete (covers lines 390-391)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=ValueError("Cannot delete")):
with patch(
"app.api.routes.users.user_crud.soft_delete",
side_effect=ValueError("Cannot delete"),
):
with pytest.raises(ValueError):
await client.delete(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)
@pytest.mark.asyncio
async def test_delete_user_unexpected_error(self, client, async_test_user, superuser_token):
async def test_delete_user_unexpected_error(
self, client, async_test_user, superuser_token
):
"""Test unexpected error handling during delete (covers lines 393-394)."""
from unittest.mock import patch
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=Exception("Unexpected")):
with patch(
"app.api.routes.users.user_crud.soft_delete",
side_effect=Exception("Unexpected"),
):
with pytest.raises(Exception):
await client.delete(
f"/api/v1/users/{async_test_user.id}",
headers={"Authorization": f"Bearer {superuser_token}"}
headers={"Authorization": f"Bearer {superuser_token}"},
)

View File

@@ -1,28 +1,32 @@
# tests/conftest.py
import os
import uuid
from datetime import datetime, timezone
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from httpx import ASGITransport, AsyncClient
# Set IS_TEST environment variable BEFORE importing app
# This prevents the scheduler from starting during tests
os.environ["IS_TEST"] = "True"
from app.main import app
from app.core.database import get_db
from app.models.user import User
from app.core.auth import get_password_hash
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
from app.core.database import get_db
from app.main import app
from app.models.user import User
from app.utils.test_utils import (
setup_async_test_db,
setup_test_db,
teardown_async_test_db,
teardown_test_db,
)
@pytest.fixture(scope="function")
def db_session():
"""
Creates a fresh SQLite in-memory database for each test function.
Yields a SQLAlchemy session that can be used for testing.
"""
# Set up the database
@@ -46,6 +50,7 @@ async def async_test_db():
yield test_engine, AsyncTestingSessionLocal
await teardown_async_test_db(test_engine)
@pytest.fixture
def user_create_data():
return {
@@ -55,7 +60,7 @@ def user_create_data():
"last_name": "User",
"phone_number": "+1234567890",
"is_superuser": False,
"preferences": None
"preferences": None,
}
@@ -102,7 +107,7 @@ async def client(async_test_db):
This overrides the get_db dependency to use the test database.
"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async def override_get_db():
async with AsyncTestingSessionLocal() as session:
@@ -176,7 +181,7 @@ async def async_test_user(async_test_db):
Password: TestPassword123
"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid.uuid4(),
@@ -202,7 +207,7 @@ async def async_test_superuser(async_test_db):
Password: SuperPassword123
"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid.uuid4(),
@@ -256,4 +261,4 @@ async def superuser_token(client, async_test_superuser):
)
assert response.status_code == 200, f"Login failed: {response.text}"
tokens = response.json()
return tokens["access_token"]
return tokens["access_token"]

View File

@@ -1,20 +1,20 @@
# tests/core/test_auth.py
import uuid
from datetime import UTC, datetime, timedelta
import pytest
from datetime import datetime, timedelta, timezone
from jose import jwt
from pydantic import ValidationError
from app.core.auth import (
verify_password,
get_password_hash,
TokenExpiredError,
TokenInvalidError,
TokenMissingClaimError,
create_access_token,
create_refresh_token,
decode_token,
get_password_hash,
get_token_data,
TokenExpiredError,
TokenInvalidError,
TokenMissingClaimError
verify_password,
)
from app.core.config import settings
@@ -58,15 +58,13 @@ class TestTokenCreation:
custom_claims = {
"email": "test@example.com",
"first_name": "Test",
"is_superuser": True
"is_superuser": True,
}
token = create_access_token(subject=user_id, claims=custom_claims)
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# Check standard claims
@@ -87,9 +85,7 @@ class TestTokenCreation:
# Decode token to verify claims
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# Check standard claims
@@ -105,23 +101,18 @@ class TestTokenCreation:
expires = timedelta(minutes=5)
# Create token with specific expiration
token = create_access_token(
subject=user_id,
expires_delta=expires
)
token = create_access_token(subject=user_id, expires_delta=expires)
# Decode token
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM]
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# Get actual expiration time from token
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
expiration = datetime.fromtimestamp(payload["exp"], tz=UTC)
# Calculate expected expiration (approximately)
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
expected_expiration = now + expires
# Difference should be small (less than 1 second)
@@ -148,7 +139,7 @@ class TestTokenDecoding:
user_id = str(uuid.uuid4())
# Create a token that's already expired by directly manipulating the payload
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
expired_time = now - timedelta(hours=1) # 1 hour in the past
# Create the expired token manually
@@ -157,13 +148,11 @@ class TestTokenDecoding:
"exp": int(expired_time.timestamp()), # Set expiration in the past
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
"type": "access",
}
expired_token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
# Attempting to decode should raise TokenExpiredError
@@ -180,20 +169,16 @@ class TestTokenDecoding:
def test_decode_token_with_missing_sub(self):
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
# Create a token without a subject
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
payload = {
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"type": "access"
"type": "access",
# No 'sub' claim
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
with pytest.raises(TokenMissingClaimError):
decode_token(token)
@@ -211,20 +196,16 @@ class TestTokenDecoding:
"""Test that a token with invalid payload structure raises TokenInvalidError"""
# Create a token with an invalid payload structure - missing 'sub' which is required
# but including 'exp' to avoid the expiration check
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
payload = {
# Missing "sub" field which is required
"exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
"jti": str(uuid.uuid4()),
"invalid_field": "test"
"invalid_field": "test",
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
# Should raise TokenMissingClaimError due to missing 'sub'
with pytest.raises(TokenMissingClaimError):
@@ -236,11 +217,7 @@ class TestTokenDecoding:
"exp": int((now + timedelta(minutes=30)).timestamp()),
}
token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
)
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
# Should raise TokenInvalidError due to ValidationError
with pytest.raises(TokenInvalidError):
@@ -249,12 +226,9 @@ class TestTokenDecoding:
def test_get_token_data(self):
"""Test extracting TokenData from a token"""
user_id = uuid.uuid4()
token = create_access_token(
subject=str(user_id),
claims={"is_superuser": True}
)
token = create_access_token(subject=str(user_id), claims={"is_superuser": True})
token_data = get_token_data(token)
assert token_data.user_id == user_id
assert token_data.is_superuser is True
assert token_data.is_superuser is True

View File

@@ -8,11 +8,11 @@ Critical security tests covering:
These tests cover critical security vulnerabilities that could be exploited.
"""
import pytest
from jose import jwt
from datetime import datetime, timedelta, timezone
from app.core.auth import decode_token, create_access_token, TokenInvalidError
from app.core.auth import TokenInvalidError, create_access_token, decode_token
from app.core.config import settings
@@ -46,13 +46,14 @@ class TestJWTAlgorithmSecurityAttacks:
"""
# Create a payload that would normally be valid (using timestamps)
import time
now = int(time.time())
payload = {
"sub": "user123",
"exp": now + 3600, # 1 hour from now
"iat": now,
"type": "access"
"type": "access",
}
# Craft a malicious token with "alg: none"
@@ -61,13 +62,13 @@ class TestJWTAlgorithmSecurityAttacks:
import json
header = {"alg": "none", "typ": "JWT"}
header_encoded = base64.urlsafe_b64encode(
json.dumps(header).encode()
).decode().rstrip("=")
header_encoded = (
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
)
payload_encoded = base64.urlsafe_b64encode(
json.dumps(payload).encode()
).decode().rstrip("=")
payload_encoded = (
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
)
# Token with no signature (algorithm "none")
malicious_token = f"{header_encoded}.{payload_encoded}."
@@ -85,22 +86,17 @@ class TestJWTAlgorithmSecurityAttacks:
import time
now = int(time.time())
payload = {
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
# Try uppercase "NONE"
header = {"alg": "NONE", "typ": "JWT"}
header_encoded = base64.urlsafe_b64encode(
json.dumps(header).encode()
).decode().rstrip("=")
header_encoded = (
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
)
payload_encoded = base64.urlsafe_b64encode(
json.dumps(payload).encode()
).decode().rstrip("=")
payload_encoded = (
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
)
malicious_token = f"{header_encoded}.{payload_encoded}."
@@ -121,15 +117,11 @@ class TestJWTAlgorithmSecurityAttacks:
before our defensive checks at line 212. This is good for security!
"""
import time
now = int(time.time())
# Create a valid payload
payload = {
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
# Encode with wrong algorithm (RS256 instead of HS256)
# This simulates an attacker trying algorithm substitution
@@ -137,9 +129,7 @@ class TestJWTAlgorithmSecurityAttacks:
try:
malicious_token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm=wrong_algorithm
payload, settings.SECRET_KEY, algorithm=wrong_algorithm
)
# Should reject the token (library catches mismatch)
@@ -156,21 +146,15 @@ class TestJWTAlgorithmSecurityAttacks:
Prevents algorithm downgrade/upgrade attacks.
"""
import time
now = int(time.time())
payload = {
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
# Create token with HS384 instead of HS256
try:
malicious_token = jwt.encode(
payload,
settings.SECRET_KEY,
algorithm="HS384"
payload, settings.SECRET_KEY, algorithm="HS384"
)
with pytest.raises(TokenInvalidError):
@@ -223,20 +207,15 @@ class TestJWTSecurityEdgeCases:
# Create token without "alg" in header
header = {"typ": "JWT"} # Missing "alg"
payload = {
"sub": "user123",
"exp": now + 3600,
"iat": now,
"type": "access"
}
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
header_encoded = base64.urlsafe_b64encode(
json.dumps(header).encode()
).decode().rstrip("=")
header_encoded = (
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
)
payload_encoded = base64.urlsafe_b64encode(
json.dumps(payload).encode()
).decode().rstrip("=")
payload_encoded = (
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
)
malicious_token = f"{header_encoded}.{payload_encoded}.fake_signature"
@@ -253,15 +232,20 @@ class TestJWTSecurityEdgeCases:
"""Test token with malformed JSON in payload."""
import base64
header = {"alg": "HS256", "typ": "JWT"}
header_encoded = base64.urlsafe_b64encode(
b'{"alg":"HS256","typ":"JWT"}'
).decode().rstrip("=")
header_encoded = (
base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}')
.decode()
.rstrip("=")
)
# Invalid JSON (missing closing brace)
invalid_payload_encoded = base64.urlsafe_b64encode(
b'{"sub":"user123"' # Invalid JSON
).decode().rstrip("=")
invalid_payload_encoded = (
base64.urlsafe_b64encode(
b'{"sub":"user123"' # Invalid JSON
)
.decode()
.rstrip("=")
)
malicious_token = f"{header_encoded}.{invalid_payload_encoded}.fake_sig"

View File

@@ -1,6 +1,7 @@
# tests/core/test_config.py
import pytest
from pydantic import ValidationError
from app.core.config import Settings
@@ -22,11 +23,15 @@ class TestSecretKeyValidation:
with pytest.raises(ValidationError) as exc_info:
Settings(SECRET_KEY=default_key, ENVIRONMENT="production")
assert "must be set to a secure random value in production" in str(exc_info.value)
assert "must be set to a secure random value in production" in str(
exc_info.value
)
def test_default_secret_key_in_development_allows_with_warning(self, caplog):
"""Test that default SECRET_KEY in development is allowed but warns"""
settings = Settings(SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development")
settings = Settings(
SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development"
)
assert settings.SECRET_KEY == "your_secret_key_here" + "x" * 14
# Note: The warning happens during validation, which we've seen works
@@ -44,19 +49,13 @@ class TestSuperuserPasswordValidation:
def test_none_password_accepted(self):
"""Test that None password is accepted (optional field)"""
settings = Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD=None
)
settings = Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=None)
assert settings.FIRST_SUPERUSER_PASSWORD is None
def test_password_too_short_raises_error(self):
"""Test that password shorter than 12 characters raises error"""
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="Short1"
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="Short1")
assert "must be at least 12 characters" in str(exc_info.value)
@@ -64,14 +63,11 @@ class TestSuperuserPasswordValidation:
"""Test that common weak passwords are rejected"""
# Test with the exact weak passwords from the validator
# These are in the weak_passwords set and should be rejected
weak_passwords = ['123456789012'] # Exactly 12 chars, in the weak set
weak_passwords = ["123456789012"] # Exactly 12 chars, in the weak set
for weak_pwd in weak_passwords:
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD=weak_pwd
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=weak_pwd)
# Should get "too weak" message
error_str = str(exc_info.value)
assert "too weak" in error_str
@@ -79,30 +75,21 @@ class TestSuperuserPasswordValidation:
def test_password_without_lowercase_rejected(self):
"""Test that password without lowercase is rejected"""
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123"
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123")
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
def test_password_without_uppercase_rejected(self):
"""Test that password without uppercase is rejected"""
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="alllowercase123"
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="alllowercase123")
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
def test_password_without_digit_rejected(self):
"""Test that password without digit is rejected"""
with pytest.raises(ValidationError) as exc_info:
Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD="NoDigitsHere"
)
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="NoDigitsHere")
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
@@ -110,8 +97,7 @@ class TestSuperuserPasswordValidation:
"""Test that strong password is accepted"""
strong_password = "StrongPassword123!"
settings = Settings(
SECRET_KEY="a" * 32,
FIRST_SUPERUSER_PASSWORD=strong_password
SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=strong_password
)
assert settings.FIRST_SUPERUSER_PASSWORD == strong_password
@@ -150,7 +136,7 @@ class TestDatabaseConfiguration:
POSTGRES_HOST="testhost",
POSTGRES_PORT="5432",
POSTGRES_DB="testdb",
DATABASE_URL=None # Don't use explicit URL
DATABASE_URL=None, # Don't use explicit URL
)
expected_url = "postgresql://testuser:testpass@testhost:5432/testdb"
@@ -159,10 +145,7 @@ class TestDatabaseConfiguration:
def test_explicit_database_url_used_when_set(self):
"""Test that explicit DATABASE_URL is used when provided"""
explicit_url = "postgresql://explicit:pass@host:5432/db"
settings = Settings(
SECRET_KEY="a" * 32,
DATABASE_URL=explicit_url
)
settings = Settings(SECRET_KEY="a" * 32, DATABASE_URL=explicit_url)
assert settings.database_url == explicit_url

View File

@@ -6,8 +6,10 @@ Critical security tests covering:
These tests prevent security misconfigurations.
"""
import pytest
import os
import pytest
from pydantic import ValidationError
@@ -43,6 +45,7 @@ class TestSecretKeySecurityValidation:
# Import Settings class fresh (to pick up new env var)
# The ValidationError should be raised during reload when Settings() is instantiated
import importlib
from app.core import config
# Reload will raise ValidationError because Settings() is instantiated at module level
@@ -58,7 +61,9 @@ class TestSecretKeySecurityValidation:
# Reload config to restore original settings
import importlib
from app.core import config
importlib.reload(config)
def test_secret_key_exactly_32_characters_accepted(self):
@@ -75,7 +80,9 @@ class TestSecretKeySecurityValidation:
os.environ["SECRET_KEY"] = key_32
import importlib
from app.core import config
importlib.reload(config)
# Should work
@@ -89,7 +96,9 @@ class TestSecretKeySecurityValidation:
os.environ.pop("SECRET_KEY", None)
import importlib
from app.core import config
importlib.reload(config)
def test_secret_key_long_enough_accepted(self):
@@ -106,7 +115,9 @@ class TestSecretKeySecurityValidation:
os.environ["SECRET_KEY"] = key_64
import importlib
from app.core import config
importlib.reload(config)
# Should work
@@ -120,7 +131,9 @@ class TestSecretKeySecurityValidation:
os.environ.pop("SECRET_KEY", None)
import importlib
from app.core import config
importlib.reload(config)
def test_default_secret_key_meets_requirements(self):
@@ -132,4 +145,6 @@ class TestSecretKeySecurityValidation:
from app.core.config import settings
# Current settings should have valid SECRET_KEY
assert len(settings.SECRET_KEY) >= 32, "Default SECRET_KEY must be at least 32 chars"
assert len(settings.SECRET_KEY) >= 32, (
"Default SECRET_KEY must be at least 32 chars"
)

View File

@@ -9,18 +9,19 @@ Covers:
- init_async_db
- close_async_db
"""
from unittest.mock import patch
import pytest
import pytest_asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import (
get_async_database_url,
get_db,
async_transaction_scope,
check_async_database_health,
init_async_db,
close_async_db,
get_async_database_url,
get_db,
init_async_db,
)
@@ -88,12 +89,13 @@ class TestAsyncTransactionScope:
async def test_transaction_scope_commits_on_success(self, async_test_db):
"""Test that successful operations are committed (covers line 138)."""
# Mock the transaction scope to use test database
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
with patch('app.core.database.SessionLocal', SessionLocal):
with patch("app.core.database.SessionLocal", SessionLocal):
async with async_transaction_scope() as db:
# Execute a simple query to verify transaction works
from sqlalchemy import text
result = await db.execute(text("SELECT 1"))
assert result is not None
# Transaction should be committed (covers line 138 debug log)
@@ -101,12 +103,13 @@ class TestAsyncTransactionScope:
@pytest.mark.asyncio
async def test_transaction_scope_rollback_on_error(self, async_test_db):
"""Test that transaction rolls back on exception."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
with patch('app.core.database.SessionLocal', SessionLocal):
with patch("app.core.database.SessionLocal", SessionLocal):
with pytest.raises(RuntimeError, match="Test error"):
async with async_transaction_scope() as db:
from sqlalchemy import text
await db.execute(text("SELECT 1"))
raise RuntimeError("Test error")
@@ -117,9 +120,9 @@ class TestCheckAsyncDatabaseHealth:
@pytest.mark.asyncio
async def test_database_health_check_success(self, async_test_db):
"""Test health check returns True on success (covers line 156)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
with patch('app.core.database.SessionLocal', SessionLocal):
with patch("app.core.database.SessionLocal", SessionLocal):
result = await check_async_database_health()
assert result is True
@@ -127,7 +130,7 @@ class TestCheckAsyncDatabaseHealth:
async def test_database_health_check_failure(self):
"""Test health check returns False on database error."""
# Mock async_transaction_scope to raise an error
with patch('app.core.database.async_transaction_scope') as mock_scope:
with patch("app.core.database.async_transaction_scope") as mock_scope:
mock_scope.side_effect = Exception("Database connection failed")
result = await check_async_database_health()
@@ -140,10 +143,10 @@ class TestInitAsyncDb:
@pytest.mark.asyncio
async def test_init_async_db_creates_tables(self, async_test_db):
"""Test init_async_db creates tables (covers lines 174-176)."""
test_engine, SessionLocal = async_test_db
test_engine, _SessionLocal = async_test_db
# Mock the engine to use test engine
with patch('app.core.database.engine', test_engine):
with patch("app.core.database.engine", test_engine):
await init_async_db()
# If no exception, tables were created successfully
@@ -155,7 +158,6 @@ class TestCloseAsyncDb:
async def test_close_async_db_disposes_engine(self):
"""Test close_async_db disposes engine (covers lines 185-186)."""
# Create a fresh engine to test closing
from app.core.database import engine
# Close connections
await close_async_db()

View File

@@ -2,14 +2,16 @@
"""
Comprehensive tests for CRUDBase class covering all error paths and edge cases.
"""
from datetime import UTC
from unittest.mock import patch
from uuid import uuid4
import pytest
from uuid import uuid4, UUID
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.orm import joinedload
from unittest.mock import AsyncMock, patch, MagicMock
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
@@ -19,7 +21,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_invalid_uuid_string(self, async_test_db):
"""Test get with invalid UUID string returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.get(session, id="invalid-uuid")
@@ -28,7 +30,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_invalid_uuid_type(self, async_test_db):
"""Test get with invalid UUID type returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.get(session, id=12345) # int instead of UUID
@@ -37,7 +39,7 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_uuid_object(self, async_test_db, async_test_user):
"""Test get with UUID object instead of string."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Pass UUID object directly
@@ -48,26 +50,24 @@ class TestCRUDBaseGet:
@pytest.mark.asyncio
async def test_get_with_options(self, async_test_db, async_test_user):
"""Test get with eager loading options (tests lines 76-78)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Test that options parameter is accepted and doesn't error
# We pass an empty list which still tests the code path
result = await user_crud.get(
session,
id=str(async_test_user.id),
options=[]
session, id=str(async_test_user.id), options=[]
)
assert result is not None
@pytest.mark.asyncio
async def test_get_database_error(self, async_test_db):
"""Test get handles database errors properly."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock execute to raise an exception
with patch.object(session, 'execute', side_effect=Exception("DB error")):
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.get(session, id=str(uuid4()))
@@ -78,7 +78,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_negative_skip(self, async_test_db):
"""Test get_multi with negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -87,7 +87,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_negative_limit(self, async_test_db):
"""Test get_multi with negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -96,7 +96,7 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_limit_too_large(self, async_test_db):
"""Test get_multi with limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
@@ -105,25 +105,20 @@ class TestCRUDBaseGetMulti:
@pytest.mark.asyncio
async def test_get_multi_with_options(self, async_test_db, async_test_user):
"""Test get_multi with eager loading options (tests lines 118-120)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Test that options parameter is accepted
results = await user_crud.get_multi(
session,
skip=0,
limit=10,
options=[]
)
results = await user_crud.get_multi(session, skip=0, limit=10, options=[])
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_multi_database_error(self, async_test_db):
"""Test get_multi handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("DB error")):
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.get_multi(session)
@@ -134,7 +129,7 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_duplicate_unique_field(self, async_test_db, async_test_user):
"""Test create with duplicate unique field raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Try to create user with duplicate email
@@ -142,7 +137,7 @@ class TestCRUDBaseCreate:
email=async_test_user.email, # Duplicate!
password="TestPassword123!",
first_name="Test",
last_name="Duplicate"
last_name="Duplicate",
)
with pytest.raises(ValueError, match="already exists"):
@@ -151,22 +146,23 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_integrity_error_non_duplicate(self, async_test_db):
"""Test create with non-duplicate IntegrityError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock commit to raise IntegrityError without "unique" in message
original_commit = session.commit
async def mock_commit():
error = IntegrityError("statement", {}, Exception("foreign key violation"))
error = IntegrityError(
"statement", {}, Exception("foreign key violation")
)
raise error
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, "commit", side_effect=mock_commit):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
with pytest.raises(ValueError, match="Database integrity error"):
@@ -175,15 +171,21 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_operational_error(self, async_test_db):
"""Test create with OperationalError (user CRUD catches as generic Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection lost"))):
with patch.object(
session,
"commit",
side_effect=OperationalError(
"statement", {}, Exception("connection lost")
),
):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -193,15 +195,19 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_data_error(self, async_test_db):
"""Test create with DataError (user CRUD catches as generic Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=DataError("statement", {}, Exception("invalid data"))):
with patch.object(
session,
"commit",
side_effect=DataError("statement", {}, Exception("invalid data")),
):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -211,15 +217,17 @@ class TestCRUDBaseCreate:
@pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db):
"""Test create with unexpected exception."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected error")):
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected error")
):
user_data = UserCreate(
email="test@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
with pytest.raises(RuntimeError, match="Unexpected error"):
@@ -232,16 +240,17 @@ class TestCRUDBaseUpdate:
@pytest.mark.asyncio
async def test_update_duplicate_unique_field(self, async_test_db, async_test_user):
"""Test update with duplicate unique field raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create another user
async with SessionLocal() as session:
from app.crud.user import user as user_crud
user2_data = UserCreate(
email="user2@example.com",
password="TestPassword123!",
first_name="User",
last_name="Two"
last_name="Two",
)
user2 = await user_crud.create(session, obj_in=user2_data)
await session.commit()
@@ -250,63 +259,89 @@ class TestCRUDBaseUpdate:
async with SessionLocal() as session:
user2_obj = await user_crud.get(session, id=str(user2.id))
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("UNIQUE constraint failed"))):
with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("UNIQUE constraint failed")
),
):
update_data = UserUpdate(email=async_test_user.email)
with pytest.raises(ValueError, match="already exists"):
await user_crud.update(session, db_obj=user2_obj, obj_in=update_data)
await user_crud.update(
session, db_obj=user2_obj, obj_in=update_data
)
@pytest.mark.asyncio
async def test_update_with_dict(self, async_test_db, async_test_user):
"""Test update with dict instead of schema."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
# Update with dict (tests lines 164-165)
updated = await user_crud.update(
session,
db_obj=user,
obj_in={"first_name": "UpdatedName"}
session, db_obj=user, obj_in={"first_name": "UpdatedName"}
)
assert updated.first_name == "UpdatedName"
@pytest.mark.asyncio
async def test_update_integrity_error(self, async_test_db, async_test_user):
"""Test update with IntegrityError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("constraint failed"))):
with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("constraint failed")
),
):
with pytest.raises(ValueError, match="Database integrity error"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user):
"""Test update with OperationalError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection error"))):
with patch.object(
session,
"commit",
side_effect=OperationalError(
"statement", {}, Exception("connection error")
),
):
with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@pytest.mark.asyncio
async def test_update_unexpected_error(self, async_test_db, async_test_user):
"""Test update with unexpected error."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
class TestCRUDBaseRemove:
@@ -315,7 +350,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_invalid_uuid(self, async_test_db):
"""Test remove with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.remove(session, id="invalid-uuid")
@@ -324,7 +359,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_with_uuid_object(self, async_test_db, async_test_user):
"""Test remove with UUID object."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a user to delete
async with SessionLocal() as session:
@@ -332,7 +367,7 @@ class TestCRUDBaseRemove:
email="todelete@example.com",
password="TestPassword123!",
first_name="To",
last_name="Delete"
last_name="Delete",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -347,7 +382,7 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_nonexistent(self, async_test_db):
"""Test remove of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.remove(session, id=str(uuid4()))
@@ -356,21 +391,31 @@ class TestCRUDBaseRemove:
@pytest.mark.asyncio
async def test_remove_integrity_error(self, async_test_db, async_test_user):
"""Test remove with IntegrityError (foreign key constraint)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock delete to raise IntegrityError
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("FOREIGN KEY constraint"))):
with pytest.raises(ValueError, match="Cannot delete.*referenced by other records"):
with patch.object(
session,
"commit",
side_effect=IntegrityError(
"statement", {}, Exception("FOREIGN KEY constraint")
),
):
with pytest.raises(
ValueError, match="Cannot delete.*referenced by other records"
):
await user_crud.remove(session, id=str(async_test_user.id))
@pytest.mark.asyncio
async def test_remove_unexpected_error(self, async_test_db, async_test_user):
"""Test remove with unexpected error."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected")
):
with pytest.raises(RuntimeError):
await user_crud.remove(session, id=str(async_test_user.id))
@@ -381,10 +426,12 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
"""Test get_multi_with_total basic functionality."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total(session, skip=0, limit=10)
items, total = await user_crud.get_multi_with_total(
session, skip=0, limit=10
)
assert isinstance(items, list)
assert isinstance(total, int)
assert total >= 1 # At least the test user
@@ -392,7 +439,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_skip(self, async_test_db):
"""Test get_multi_with_total with negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -401,7 +448,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db):
"""Test get_multi_with_total with negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -410,28 +457,34 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
"""Test get_multi_with_total with limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, limit=1001)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_filters(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with filters."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
filters = {"email": async_test_user.email}
items, total = await user_crud.get_multi_with_total(session, filters=filters)
items, total = await user_crud.get_multi_with_total(
session, filters=filters
)
assert total == 1
assert len(items) == 1
assert items[0].email == async_test_user.email
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_sorting_asc(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with ascending sort."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create additional users
async with SessionLocal() as session:
@@ -439,13 +492,13 @@ class TestCRUDBaseGetMultiWithTotal:
email="aaa@example.com",
password="TestPassword123!",
first_name="AAA",
last_name="User"
last_name="User",
)
user_data2 = UserCreate(
email="zzz@example.com",
password="TestPassword123!",
first_name="ZZZ",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
@@ -460,9 +513,11 @@ class TestCRUDBaseGetMultiWithTotal:
assert items[0].email == "aaa@example.com"
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_sorting_desc(
self, async_test_db, async_test_user
):
"""Test get_multi_with_total with descending sort."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create additional users
async with SessionLocal() as session:
@@ -470,20 +525,20 @@ class TestCRUDBaseGetMultiWithTotal:
email="bbb@example.com",
password="TestPassword123!",
first_name="BBB",
last_name="User"
last_name="User",
)
user_data2 = UserCreate(
email="ccc@example.com",
password="TestPassword123!",
first_name="CCC",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
await session.commit()
async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total(
items, _total = await user_crud.get_multi_with_total(
session, sort_by="email", sort_order="desc", limit=1
)
assert len(items) == 1
@@ -492,7 +547,7 @@ class TestCRUDBaseGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_with_pagination(self, async_test_db):
"""Test get_multi_with_total pagination works correctly."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create minimal users for pagination test (3 instead of 5)
async with SessionLocal() as session:
@@ -501,19 +556,23 @@ class TestCRUDBaseGetMultiWithTotal:
email=f"user{i}@example.com",
password="TestPassword123!",
first_name=f"User{i}",
last_name="Test"
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
await session.commit()
async with SessionLocal() as session:
# Get first page
items1, total = await user_crud.get_multi_with_total(session, skip=0, limit=2)
items1, total = await user_crud.get_multi_with_total(
session, skip=0, limit=2
)
assert len(items1) == 2
assert total >= 3
# Get second page
items2, total2 = await user_crud.get_multi_with_total(session, skip=2, limit=2)
items2, total2 = await user_crud.get_multi_with_total(
session, skip=2, limit=2
)
assert len(items2) >= 1
assert total2 == total
@@ -529,7 +588,7 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio
async def test_count_basic(self, async_test_db, async_test_user):
"""Test count returns correct number."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
count = await user_crud.count(session)
@@ -539,7 +598,7 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio
async def test_count_multiple_users(self, async_test_db, async_test_user):
"""Test count with multiple users."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create additional users
async with SessionLocal() as session:
@@ -549,13 +608,13 @@ class TestCRUDBaseCount:
email="count1@example.com",
password="TestPassword123!",
first_name="Count",
last_name="One"
last_name="One",
)
user_data2 = UserCreate(
email="count2@example.com",
password="TestPassword123!",
first_name="Count",
last_name="Two"
last_name="Two",
)
await user_crud.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2)
@@ -568,10 +627,10 @@ class TestCRUDBaseCount:
@pytest.mark.asyncio
async def test_count_database_error(self, async_test_db):
"""Test count handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("DB error")):
with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
await user_crud.count(session)
@@ -582,7 +641,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio
async def test_exists_true(self, async_test_db, async_test_user):
"""Test exists returns True for existing record."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(async_test_user.id))
@@ -591,7 +650,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio
async def test_exists_false(self, async_test_db):
"""Test exists returns False for non-existent record."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(uuid4()))
@@ -600,7 +659,7 @@ class TestCRUDBaseExists:
@pytest.mark.asyncio
async def test_exists_invalid_uuid(self, async_test_db):
"""Test exists returns False for invalid UUID."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.exists(session, id="invalid-uuid")
@@ -613,7 +672,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_success(self, async_test_db):
"""Test soft delete sets deleted_at timestamp."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a user to soft delete
async with SessionLocal() as session:
@@ -621,7 +680,7 @@ class TestCRUDBaseSoftDelete:
email="softdelete@example.com",
password="TestPassword123!",
first_name="Soft",
last_name="Delete"
last_name="Delete",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -636,7 +695,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_invalid_uuid(self, async_test_db):
"""Test soft delete with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id="invalid-uuid")
@@ -645,7 +704,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_nonexistent(self, async_test_db):
"""Test soft delete of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id=str(uuid4()))
@@ -654,7 +713,7 @@ class TestCRUDBaseSoftDelete:
@pytest.mark.asyncio
async def test_soft_delete_with_uuid_object(self, async_test_db):
"""Test soft delete with UUID object."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a user to soft delete
async with SessionLocal() as session:
@@ -662,7 +721,7 @@ class TestCRUDBaseSoftDelete:
email="softdelete2@example.com",
password="TestPassword123!",
first_name="Soft",
last_name="Delete2"
last_name="Delete2",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -681,7 +740,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_success(self, async_test_db):
"""Test restore clears deleted_at timestamp."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create and soft delete a user
async with SessionLocal() as session:
@@ -689,7 +748,7 @@ class TestCRUDBaseRestore:
email="restore@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test"
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -707,7 +766,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_invalid_uuid(self, async_test_db):
"""Test restore with invalid UUID returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.restore(session, id="invalid-uuid")
@@ -716,7 +775,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_nonexistent(self, async_test_db):
"""Test restore of nonexistent record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
result = await user_crud.restore(session, id=str(uuid4()))
@@ -725,7 +784,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_not_deleted(self, async_test_db, async_test_user):
"""Test restore of non-deleted record returns None."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Try to restore a user that's not deleted
@@ -735,7 +794,7 @@ class TestCRUDBaseRestore:
@pytest.mark.asyncio
async def test_restore_with_uuid_object(self, async_test_db):
"""Test restore with UUID object."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create and soft delete a user
async with SessionLocal() as session:
@@ -743,7 +802,7 @@ class TestCRUDBaseRestore:
email="restore2@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test2"
last_name="Test2",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -765,7 +824,7 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_skip(self, async_test_db):
"""Test that negative skip raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
@@ -774,7 +833,7 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db):
"""Test that negative limit raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
@@ -783,23 +842,22 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
"""Test that limit > 1000 raises ValueError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
async def test_get_multi_with_total_with_filters(
self, async_test_db, async_test_user
):
"""Test pagination with filters (covers lines 270-273)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
filters={"is_active": True}
session, skip=0, limit=10, filters={"is_active": True}
)
assert isinstance(users, list)
assert total >= 0
@@ -807,30 +865,22 @@ class TestCRUDBasePaginationValidation:
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db):
"""Test pagination with descending sort (covers lines 283-284)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="created_at",
sort_order="desc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
)
assert isinstance(users, list)
@pytest.mark.asyncio
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db):
"""Test pagination with ascending sort (covers lines 285-286)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="created_at",
sort_order="asc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
)
assert isinstance(users, list)
@@ -842,13 +892,15 @@ class TestCRUDBaseModelsWithoutSoftDelete:
"""
@pytest.mark.asyncio
async def test_soft_delete_model_without_deleted_at(self, async_test_db, async_test_user):
async def test_soft_delete_model_without_deleted_at(
self, async_test_db, async_test_user
):
"""Test soft_delete on Organization model (no deleted_at) raises ValueError (covers lines 342-343)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.crud.organization import organization as org_crud
from app.models.organization import Organization
async with SessionLocal() as session:
org = Organization(name="Test Org", slug="test-org")
@@ -864,11 +916,11 @@ class TestCRUDBaseModelsWithoutSoftDelete:
@pytest.mark.asyncio
async def test_restore_model_without_deleted_at(self, async_test_db):
"""Test restore on Organization model (no deleted_at) raises ValueError (covers lines 383-384)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at)
from app.models.organization import Organization
from app.crud.organization import organization as org_crud
from app.models.organization import Organization
async with SessionLocal() as session:
org = Organization(name="Restore Test", slug="restore-test")
@@ -889,14 +941,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
"""
@pytest.mark.asyncio
async def test_get_with_real_eager_loading_options(self, async_test_db, async_test_user):
async def test_get_with_real_eager_loading_options(
self, async_test_db, async_test_user
):
"""Test get() with actual eager loading options (covers lines 77-78)."""
from datetime import datetime, timedelta, timezone
test_engine, SessionLocal = async_test_db
from datetime import datetime, timedelta
_test_engine, SessionLocal = async_test_db
# Create a session for the user
from app.models.user_session import UserSession
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
async with SessionLocal() as session:
user_session = UserSession(
@@ -905,8 +960,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
device_id="test-device",
ip_address="192.168.1.1",
user_agent="Test Agent",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60),
)
session.add(user_session)
await session.commit()
@@ -917,7 +972,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
result = await session_crud.get(
session,
id=str(session_id),
options=[joinedload(UserSession.user)] # Real option, not empty list
options=[joinedload(UserSession.user)], # Real option, not empty list
)
assert result is not None
assert result.id == session_id
@@ -925,14 +980,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
assert result.user.email == async_test_user.email
@pytest.mark.asyncio
async def test_get_multi_with_real_eager_loading_options(self, async_test_db, async_test_user):
async def test_get_multi_with_real_eager_loading_options(
self, async_test_db, async_test_user
):
"""Test get_multi() with actual eager loading options (covers lines 119-120)."""
from datetime import datetime, timedelta, timezone
test_engine, SessionLocal = async_test_db
from datetime import datetime, timedelta
_test_engine, SessionLocal = async_test_db
# Create multiple sessions for the user
from app.models.user_session import UserSession
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
async with SessionLocal() as session:
for i in range(3):
@@ -942,8 +1000,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
device_id=f"device-{i}",
ip_address=f"192.168.1.{i}",
user_agent=f"Agent {i}",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60),
)
session.add(user_session)
await session.commit()
@@ -954,7 +1012,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
session,
skip=0,
limit=10,
options=[joinedload(UserSession.user)] # Real option, not empty list
options=[joinedload(UserSession.user)], # Real option, not empty list
)
assert len(results) >= 3
# Verify we can access user without additional queries

View File

@@ -3,13 +3,15 @@
Comprehensive tests for base CRUD database failure scenarios.
Tests exception handling, rollbacks, and error messages.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.exc import DataError, OperationalError
from app.crud.user import user as user_crud
from app.schemas.users import UserCreate, UserUpdate
from app.schemas.users import UserCreate
class TestBaseCRUDCreateFailures:
@@ -18,19 +20,24 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio
async def test_create_operational_error_triggers_rollback(self, async_test_db):
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Connection lost", {}, Exception("DB connection failed"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
async def mock_commit():
raise OperationalError(
"Connection lost", {}, Exception("DB connection failed")
)
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="operror@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -43,19 +50,22 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio
async def test_create_data_error_triggers_rollback(self, async_test_db):
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise DataError("Invalid data type", {}, Exception("Data overflow"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="dataerror@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
# User CRUD catches this as generic Exception and re-raises
@@ -67,19 +77,22 @@ class TestBaseCRUDCreateFailures:
@pytest.mark.asyncio
async def test_create_unexpected_exception_triggers_rollback(self, async_test_db):
"""Test that unexpected exceptions trigger rollback and re-raise."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Unexpected database error")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
user_data = UserCreate(
email="unexpected@example.com",
password="TestPassword123!",
first_name="Test",
last_name="User"
last_name="User",
)
with pytest.raises(RuntimeError, match="Unexpected database error"):
@@ -94,7 +107,7 @@ class TestBaseCRUDUpdateFailures:
@pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user):
"""Test update with OperationalError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -102,17 +115,21 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit():
raise OperationalError("Connection timeout", {}, Exception("Timeout"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_data_error(self, async_test_db, async_test_user):
"""Test update with DataError."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -120,17 +137,21 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit():
raise DataError("Invalid data", {}, Exception("Data type mismatch"))
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_unexpected_error(self, async_test_db, async_test_user):
"""Test update with unexpected error."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -138,10 +159,14 @@ class TestBaseCRUDUpdateFailures:
async def mock_commit():
raise KeyError("Unexpected error")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(KeyError):
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
mock_rollback.assert_called_once()
@@ -150,16 +175,21 @@ class TestBaseCRUDRemoveFailures:
"""Test base CRUD remove method exception handling."""
@pytest.mark.asyncio
async def test_remove_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
async def test_remove_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test that unexpected errors in remove trigger rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Database write failed")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Database write failed"):
await user_crud.remove(session, id=str(async_test_user.id))
@@ -172,16 +202,15 @@ class TestBaseCRUDGetMultiWithTotalFailures:
@pytest.mark.asyncio
async def test_get_multi_with_total_database_error(self, async_test_db):
"""Test get_multi_with_total handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
# Mock execute to raise an error
original_execute = session.execute
async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("Database error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi_with_total(session, skip=0, limit=10)
@@ -192,13 +221,14 @@ class TestBaseCRUDCountFailures:
@pytest.mark.asyncio
async def test_count_database_error_propagates(self, async_test_db):
"""Test count propagates database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Count failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.count(session)
@@ -207,16 +237,21 @@ class TestBaseCRUDSoftDeleteFailures:
"""Test soft_delete method exception handling."""
@pytest.mark.asyncio
async def test_soft_delete_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
async def test_soft_delete_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test soft_delete handles unexpected errors with rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Soft delete failed")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Soft delete failed"):
await user_crud.soft_delete(session, id=str(async_test_user.id))
@@ -229,7 +264,7 @@ class TestBaseCRUDRestoreFailures:
@pytest.mark.asyncio
async def test_restore_unexpected_error_triggers_rollback(self, async_test_db):
"""Test restore handles unexpected errors with rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# First create and soft delete a user
async with SessionLocal() as session:
@@ -237,7 +272,7 @@ class TestBaseCRUDRestoreFailures:
email="restore_test@example.com",
password="TestPassword123!",
first_name="Restore",
last_name="Test"
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -248,11 +283,14 @@ class TestBaseCRUDRestoreFailures:
# Now test restore failure
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Restore failed")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(RuntimeError, match="Restore failed"):
await user_crud.restore(session, id=str(user_id))
@@ -265,13 +303,14 @@ class TestBaseCRUDGetFailures:
@pytest.mark.asyncio
async def test_get_database_error_propagates(self, async_test_db):
"""Test get propagates database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Get failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get(session, id=str(uuid4()))
@@ -282,12 +321,13 @@ class TestBaseCRUDGetMultiFailures:
@pytest.mark.asyncio
async def test_get_multi_database_error_propagates(self, async_test_db):
"""Test get_multi propagates database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Query failed", {}, Exception("DB error"))
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await user_crud.get_multi(session, skip=0, limit=10)

File diff suppressed because it is too large Load Diff

View File

@@ -2,10 +2,12 @@
"""
Comprehensive tests for async session CRUD operations.
"""
import pytest
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate
@@ -17,7 +19,7 @@ class TestGetByJti:
@pytest.mark.asyncio
async def test_get_by_jti_success(self, async_test_db, async_test_user):
"""Test getting session by JTI."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -27,8 +29,8 @@ class TestGetByJti:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -41,7 +43,7 @@ class TestGetByJti:
@pytest.mark.asyncio
async def test_get_by_jti_not_found(self, async_test_db):
"""Test getting non-existent JTI returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="nonexistent")
@@ -54,7 +56,7 @@ class TestGetActiveByJti:
@pytest.mark.asyncio
async def test_get_active_by_jti_success(self, async_test_db, async_test_user):
"""Test getting active session by JTI."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -64,8 +66,8 @@ class TestGetActiveByJti:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -78,7 +80,7 @@ class TestGetActiveByJti:
@pytest.mark.asyncio
async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user):
"""Test getting inactive session by JTI returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -88,8 +90,8 @@ class TestGetActiveByJti:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -105,7 +107,7 @@ class TestGetUserSessions:
@pytest.mark.asyncio
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
"""Test getting only active user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
active = UserSession(
@@ -115,8 +117,8 @@ class TestGetUserSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
inactive = UserSession(
user_id=async_test_user.id,
@@ -125,17 +127,15 @@ class TestGetUserSessions:
ip_address="192.168.1.2",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add_all([active, inactive])
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id),
active_only=True
session, user_id=str(async_test_user.id), active_only=True
)
assert len(results) == 1
assert results[0].is_active is True
@@ -143,7 +143,7 @@ class TestGetUserSessions:
@pytest.mark.asyncio
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
"""Test getting all user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
for i in range(3):
@@ -154,17 +154,15 @@ class TestGetUserSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=i % 2 == 0,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id),
active_only=False
session, user_id=str(async_test_user.id), active_only=False
)
assert len(results) == 3
@@ -175,7 +173,7 @@ class TestCreateSession:
@pytest.mark.asyncio
async def test_create_session_success(self, async_test_db, async_test_user):
"""Test successfully creating a session_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
session_data = SessionCreate(
@@ -185,10 +183,10 @@ class TestCreateSession:
device_id="device_123",
ip_address="192.168.1.100",
user_agent="Mozilla/5.0",
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=7),
location_city="San Francisco",
location_country="USA"
location_country="USA",
)
result = await session_crud.create_session(session, obj_in=session_data)
@@ -204,7 +202,7 @@ class TestDeactivate:
@pytest.mark.asyncio
async def test_deactivate_success(self, async_test_db, async_test_user):
"""Test successfully deactivating a session_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -214,8 +212,8 @@ class TestDeactivate:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -229,7 +227,7 @@ class TestDeactivate:
@pytest.mark.asyncio
async def test_deactivate_not_found(self, async_test_db):
"""Test deactivating non-existent session returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(uuid4()))
@@ -240,9 +238,11 @@ class TestDeactivateAllUserSessions:
"""Tests for deactivate_all_user_sessions method."""
@pytest.mark.asyncio
async def test_deactivate_all_user_sessions_success(self, async_test_db, async_test_user):
async def test_deactivate_all_user_sessions_success(
self, async_test_db, async_test_user
):
"""Test deactivating all user sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Create minimal sessions for test (2 instead of 5)
@@ -254,16 +254,15 @@ class TestDeactivateAllUserSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.deactivate_all_user_sessions(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 2
@@ -274,7 +273,7 @@ class TestUpdateLastUsed:
@pytest.mark.asyncio
async def test_update_last_used_success(self, async_test_db, async_test_user):
"""Test updating last_used_at timestamp."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -284,8 +283,8 @@ class TestUpdateLastUsed:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
@@ -303,7 +302,7 @@ class TestGetUserSessionCount:
@pytest.mark.asyncio
async def test_get_user_session_count_success(self, async_test_db, async_test_user):
"""Test getting user session count."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
for i in range(3):
@@ -314,28 +313,26 @@ class TestGetUserSessionCount:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(sess)
await session.commit()
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 3
@pytest.mark.asyncio
async def test_get_user_session_count_empty(self, async_test_db):
"""Test getting session count for user with no sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count(
session,
user_id=str(uuid4())
session, user_id=str(uuid4())
)
assert count == 0
@@ -346,7 +343,7 @@ class TestUpdateRefreshToken:
@pytest.mark.asyncio
async def test_update_refresh_token_success(self, async_test_db, async_test_user):
"""Test updating refresh token JTI and expiration."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -356,26 +353,34 @@ class TestUpdateRefreshToken:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
await session.refresh(user_session)
new_jti = "new_jti_123"
new_expires = datetime.now(timezone.utc) + timedelta(days=14)
new_expires = datetime.now(UTC) + timedelta(days=14)
result = await session_crud.update_refresh_token(
session,
session=user_session,
new_jti=new_jti,
new_expires_at=new_expires
new_expires_at=new_expires,
)
assert result.refresh_token_jti == new_jti
# Compare timestamps ignoring timezone info
assert abs((result.expires_at.replace(tzinfo=None) - new_expires.replace(tzinfo=None)).total_seconds()) < 1
assert (
abs(
(
result.expires_at.replace(tzinfo=None)
- new_expires.replace(tzinfo=None)
).total_seconds()
)
< 1
)
class TestCleanupExpired:
@@ -384,7 +389,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio
async def test_cleanup_expired_success(self, async_test_db, async_test_user):
"""Test cleaning up old expired inactive sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired inactive session
async with AsyncTestingSessionLocal() as session:
@@ -395,9 +400,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
created_at=datetime.now(timezone.utc) - timedelta(days=35)
expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(UTC) - timedelta(days=35),
)
session.add(old_session)
await session.commit()
@@ -410,7 +415,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio
async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user):
"""Test that cleanup keeps recent expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create recent expired inactive session (less than keep_days old)
async with AsyncTestingSessionLocal() as session:
@@ -421,9 +426,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2),
created_at=datetime.now(timezone.utc) - timedelta(days=1)
expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2),
created_at=datetime.now(UTC) - timedelta(days=1),
)
session.add(recent_session)
await session.commit()
@@ -436,7 +441,7 @@ class TestCleanupExpired:
@pytest.mark.asyncio
async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user):
"""Test that cleanup does not delete active sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create old expired but ACTIVE session
async with AsyncTestingSessionLocal() as session:
@@ -447,9 +452,9 @@ class TestCleanupExpired:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True, # Active
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
created_at=datetime.now(timezone.utc) - timedelta(days=35)
expires_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC) - timedelta(days=35),
created_at=datetime.now(UTC) - timedelta(days=35),
)
session.add(active_session)
await session.commit()
@@ -464,9 +469,11 @@ class TestCleanupExpiredForUser:
"""Tests for cleanup_expired_for_user method."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_success(self, async_test_db, async_test_user):
async def test_cleanup_expired_for_user_success(
self, async_test_db, async_test_user
):
"""Test cleaning up expired sessions for specific user."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired inactive session for user
async with AsyncTestingSessionLocal() as session:
@@ -477,8 +484,8 @@ class TestCleanupExpiredForUser:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
session.add(expired_session)
await session.commit()
@@ -486,27 +493,27 @@ class TestCleanupExpiredForUser:
# Cleanup for user
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 1
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db):
"""Test cleanup with invalid user UUID."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError, match="Invalid user ID format"):
await session_crud.cleanup_expired_for_user(
session,
user_id="not-a-valid-uuid"
session, user_id="not-a-valid-uuid"
)
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_keeps_active(self, async_test_db, async_test_user):
async def test_cleanup_expired_for_user_keeps_active(
self, async_test_db, async_test_user
):
"""Test that cleanup for user keeps active sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create expired but active session
async with AsyncTestingSessionLocal() as session:
@@ -517,8 +524,8 @@ class TestCleanupExpiredForUser:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True, # Active
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2),
)
session.add(active_session)
await session.commit()
@@ -526,8 +533,7 @@ class TestCleanupExpiredForUser:
# Cleanup
async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
assert count == 0 # Should not delete active sessions
@@ -536,9 +542,11 @@ class TestGetUserSessionsWithUser:
"""Tests for get_user_sessions with eager loading."""
@pytest.mark.asyncio
async def test_get_user_sessions_with_user_relationship(self, async_test_db, async_test_user):
async def test_get_user_sessions_with_user_relationship(
self, async_test_db, async_test_user
):
"""Test getting sessions with user relationship loaded."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_session = UserSession(
@@ -548,8 +556,8 @@ class TestGetUserSessionsWithUser:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -557,8 +565,6 @@ class TestGetUserSessionsWithUser:
# Get with user relationship
async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id),
with_user=True
session, user_id=str(async_test_user.id), with_user=True
)
assert len(results) >= 1

View File

@@ -2,12 +2,14 @@
"""
Comprehensive tests for session CRUD database failure scenarios.
"""
import pytest
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch
from sqlalchemy.exc import OperationalError, IntegrityError
from datetime import datetime, timedelta, timezone
from uuid import uuid4
import pytest
from sqlalchemy.exc import OperationalError
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate
@@ -19,13 +21,14 @@ class TestSessionCRUDGetByJtiFailures:
@pytest.mark.asyncio
async def test_get_by_jti_database_error(self, async_test_db):
"""Test get_by_jti handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("DB connection lost", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_by_jti(session, jti="test_jti")
@@ -36,13 +39,14 @@ class TestSessionCRUDGetActiveByJtiFailures:
@pytest.mark.asyncio
async def test_get_active_by_jti_database_error(self, async_test_db):
"""Test get_active_by_jti handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Query timeout", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_active_by_jti(session, jti="test_jti")
@@ -51,19 +55,21 @@ class TestSessionCRUDGetUserSessionsFailures:
"""Test get_user_sessions exception handling."""
@pytest.mark.asyncio
async def test_get_user_sessions_database_error(self, async_test_db, async_test_user):
async def test_get_user_sessions_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_sessions handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Database error", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_sessions(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
@@ -71,24 +77,29 @@ class TestSessionCRUDCreateSessionFailures:
"""Test create_session exception handling."""
@pytest.mark.asyncio
async def test_create_session_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_create_session_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles commit failures with rollback."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Commit failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
with pytest.raises(ValueError, match="Failed to create session"):
@@ -97,24 +108,29 @@ class TestSessionCRUDCreateSessionFailures:
mock_rollback.assert_called_once()
@pytest.mark.asyncio
async def test_create_session_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
async def test_create_session_unexpected_error_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test create_session handles unexpected errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise RuntimeError("Unexpected error")
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
session_data = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid4()),
device_name="Test Device",
ip_address="127.0.0.1",
user_agent="Test Agent",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
with pytest.raises(ValueError, match="Failed to create session"):
@@ -127,9 +143,11 @@ class TestSessionCRUDDeactivateFailures:
"""Test deactivate exception handling."""
@pytest.mark.asyncio
async def test_deactivate_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_deactivate_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session first
async with SessionLocal() as session:
@@ -140,8 +158,8 @@ class TestSessionCRUDDeactivateFailures:
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -150,13 +168,18 @@ class TestSessionCRUDDeactivateFailures:
# Test deactivate failure
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Deactivate failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate(session, session_id=str(session_id))
await session_crud.deactivate(
session, session_id=str(session_id)
)
mock_rollback.assert_called_once()
@@ -165,20 +188,24 @@ class TestSessionCRUDDeactivateAllFailures:
"""Test deactivate_all_user_sessions exception handling."""
@pytest.mark.asyncio
async def test_deactivate_all_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_deactivate_all_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test deactivate_all handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Bulk deactivate failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.deactivate_all_user_sessions(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
@@ -188,9 +215,11 @@ class TestSessionCRUDUpdateLastUsedFailures:
"""Test update_last_used exception handling."""
@pytest.mark.asyncio
async def test_update_last_used_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_update_last_used_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_last_used handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session
async with SessionLocal() as session:
@@ -201,8 +230,8 @@ class TestSessionCRUDUpdateLastUsedFailures:
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC) - timedelta(hours=1),
)
session.add(user_session)
await session.commit()
@@ -211,15 +240,19 @@ class TestSessionCRUDUpdateLastUsedFailures:
# Test update failure
async with SessionLocal() as session:
from sqlalchemy import select
from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one()
async def mock_commit():
raise OperationalError("Update failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_last_used(session, session=sess)
@@ -230,9 +263,11 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
"""Test update_refresh_token exception handling."""
@pytest.mark.asyncio
async def test_update_refresh_token_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_update_refresh_token_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test update_refresh_token handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Create a session
async with SessionLocal() as session:
@@ -243,8 +278,8 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
ip_address="127.0.0.1",
user_agent="Test Agent",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC),
)
session.add(user_session)
await session.commit()
@@ -253,21 +288,25 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
# Test update failure
async with SessionLocal() as session:
from sqlalchemy import select
from app.models.user_session import UserSession as US
result = await session.execute(select(US).where(US.id == user_session.id))
sess = result.scalar_one()
async def mock_commit():
raise OperationalError("Token update failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.update_refresh_token(
session,
session=sess,
new_jti=str(uuid4()),
new_expires_at=datetime.now(timezone.utc) + timedelta(days=14)
new_expires_at=datetime.now(UTC) + timedelta(days=14),
)
mock_rollback.assert_called_once()
@@ -277,16 +316,21 @@ class TestSessionCRUDCleanupExpiredFailures:
"""Test cleanup_expired exception handling."""
@pytest.mark.asyncio
async def test_cleanup_expired_commit_failure_triggers_rollback(self, async_test_db):
async def test_cleanup_expired_commit_failure_triggers_rollback(
self, async_test_db
):
"""Test cleanup_expired handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("Cleanup failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired(session, keep_days=30)
@@ -297,20 +341,24 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
"""Test cleanup_expired_for_user exception handling."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(
self, async_test_db, async_test_user
):
"""Test cleanup_expired_for_user handles commit failures."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_commit():
raise OperationalError("User cleanup failed", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(OperationalError):
await session_crud.cleanup_expired_for_user(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)
mock_rollback.assert_called_once()
@@ -320,17 +368,19 @@ class TestSessionCRUDGetUserSessionCountFailures:
"""Test get_user_session_count exception handling."""
@pytest.mark.asyncio
async def test_get_user_session_count_database_error(self, async_test_db, async_test_user):
async def test_get_user_session_count_database_error(
self, async_test_db, async_test_user
):
"""Test get_user_session_count handles database errors."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
async def mock_execute(*args, **kwargs):
raise OperationalError("Count query failed", {}, Exception())
with patch.object(session, 'execute', side_effect=mock_execute):
with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError):
await session_crud.get_user_session_count(
session,
user_id=str(async_test_user.id)
session, user_id=str(async_test_user.id)
)

View File

@@ -2,12 +2,10 @@
"""
Comprehensive tests for async user CRUD operations.
"""
import pytest
from datetime import datetime, timezone
from uuid import uuid4
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
@@ -17,7 +15,7 @@ class TestGetByEmail:
@pytest.mark.asyncio
async def test_get_by_email_success(self, async_test_db, async_test_user):
"""Test getting user by email."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email=async_test_user.email)
@@ -28,10 +26,12 @@ class TestGetByEmail:
@pytest.mark.asyncio
async def test_get_by_email_not_found(self, async_test_db):
"""Test getting non-existent email returns None."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email="nonexistent@example.com")
result = await user_crud.get_by_email(
session, email="nonexistent@example.com"
)
assert result is None
@@ -41,7 +41,7 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_user_success(self, async_test_db):
"""Test successfully creating a user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
@@ -49,7 +49,7 @@ class TestCreate:
password="SecurePass123!",
first_name="New",
last_name="User",
phone_number="+1234567890"
phone_number="+1234567890",
)
result = await user_crud.create(session, obj_in=user_data)
@@ -65,7 +65,7 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_superuser_success(self, async_test_db):
"""Test creating a superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
@@ -73,7 +73,7 @@ class TestCreate:
password="SuperPass123!",
first_name="Super",
last_name="User",
is_superuser=True
is_superuser=True,
)
result = await user_crud.create(session, obj_in=user_data)
@@ -83,14 +83,14 @@ class TestCreate:
@pytest.mark.asyncio
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
"""Test creating user with duplicate email raises ValueError."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email=async_test_user.email, # Duplicate email
password="AnotherPass123!",
first_name="Duplicate",
last_name="User"
last_name="User",
)
with pytest.raises(ValueError) as exc_info:
@@ -105,16 +105,14 @@ class TestUpdate:
@pytest.mark.asyncio
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
"""Test updating basic user fields."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Get fresh copy of user
user = await user_crud.get(session, id=str(async_test_user.id))
update_data = UserUpdate(
first_name="Updated",
last_name="Name",
phone_number="+9876543210"
first_name="Updated", last_name="Name", phone_number="+9876543210"
)
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
@@ -125,7 +123,7 @@ class TestUpdate:
@pytest.mark.asyncio
async def test_update_user_password(self, async_test_db):
"""Test updating user password."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a fresh user for this test
async with AsyncTestingSessionLocal() as session:
@@ -133,7 +131,7 @@ class TestUpdate:
email="passwordtest@example.com",
password="OldPassword123!",
first_name="Pass",
last_name="Test"
last_name="Test",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -149,12 +147,14 @@ class TestUpdate:
await session.refresh(result)
assert result.password_hash != old_password_hash
assert result.password_hash is not None
assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed
assert (
"NewDifferentPassword123!" not in result.password_hash
) # Should be hashed
@pytest.mark.asyncio
async def test_update_user_with_dict(self, async_test_db, async_test_user):
"""Test updating user with dictionary."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -171,13 +171,11 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
"""Test basic pagination."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10
session, skip=0, limit=10
)
assert total >= 1
assert len(users) >= 1
@@ -186,7 +184,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
"""Test sorting in ascending order."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -195,17 +193,13 @@ class TestGetMultiWithTotal:
email=f"sort{i}@example.com",
password="SecurePass123!",
first_name=f"User{i}",
last_name="Test"
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="email",
sort_order="asc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="asc"
)
# Check if sorted (at least the test users)
@@ -216,7 +210,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
"""Test sorting in descending order."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -225,17 +219,13 @@ class TestGetMultiWithTotal:
email=f"desc{i}@example.com",
password="SecurePass123!",
first_name=f"User{i}",
last_name="Test"
last_name="Test",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=10,
sort_by="email",
sort_order="desc"
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="desc"
)
# Check if sorted descending (at least the test users)
@@ -246,7 +236,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_filtering(self, async_test_db):
"""Test filtering by field."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create active and inactive users
async with AsyncTestingSessionLocal() as session:
@@ -254,7 +244,7 @@ class TestGetMultiWithTotal:
email="active@example.com",
password="SecurePass123!",
first_name="Active",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=active_user)
@@ -262,23 +252,18 @@ class TestGetMultiWithTotal:
email="inactive@example.com",
password="SecurePass123!",
first_name="Inactive",
last_name="User"
last_name="User",
)
created_inactive = await user_crud.create(session, obj_in=inactive_user)
# Deactivate the user
await user_crud.update(
session,
db_obj=created_inactive,
obj_in={"is_active": False}
session, db_obj=created_inactive, obj_in={"is_active": False}
)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=100,
filters={"is_active": True}
users, _total = await user_crud.get_multi_with_total(
session, skip=0, limit=100, filters={"is_active": True}
)
# All returned users should be active
@@ -287,7 +272,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_search(self, async_test_db):
"""Test search functionality."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user with unique name
async with AsyncTestingSessionLocal() as session:
@@ -295,16 +280,13 @@ class TestGetMultiWithTotal:
email="searchable@example.com",
password="SecurePass123!",
first_name="Searchable",
last_name="UserName"
last_name="UserName",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=100,
search="Searchable"
session, skip=0, limit=100, search="Searchable"
)
assert total >= 1
@@ -313,7 +295,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_pagination(self, async_test_db):
"""Test pagination with skip and limit."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
async with AsyncTestingSessionLocal() as session:
@@ -322,23 +304,19 @@ class TestGetMultiWithTotal:
email=f"page{i}@example.com",
password="SecurePass123!",
first_name=f"Page{i}",
last_name="User"
last_name="User",
)
await user_crud.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session:
# Get first page
users_page1, total = await user_crud.get_multi_with_total(
session,
skip=0,
limit=2
session, skip=0, limit=2
)
# Get second page
users_page2, total2 = await user_crud.get_multi_with_total(
session,
skip=2,
limit=2
session, skip=2, limit=2
)
# Total should be same
@@ -349,7 +327,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
"""Test validation fails for negative skip."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
@@ -360,7 +338,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
"""Test validation fails for negative limit."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
@@ -371,7 +349,7 @@ class TestGetMultiWithTotal:
@pytest.mark.asyncio
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
"""Test validation fails for limit > 1000."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
@@ -386,7 +364,7 @@ class TestBulkUpdateStatus:
@pytest.mark.asyncio
async def test_bulk_update_status_success(self, async_test_db):
"""Test bulk updating user status."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
@@ -396,7 +374,7 @@ class TestBulkUpdateStatus:
email=f"bulk{i}@example.com",
password="SecurePass123!",
first_name=f"Bulk{i}",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
@@ -404,9 +382,7 @@ class TestBulkUpdateStatus:
# Bulk deactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session,
user_ids=user_ids,
is_active=False
session, user_ids=user_ids, is_active=False
)
assert count == 3
@@ -419,20 +395,18 @@ class TestBulkUpdateStatus:
@pytest.mark.asyncio
async def test_bulk_update_status_empty_list(self, async_test_db):
"""Test bulk update with empty list returns 0."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session,
user_ids=[],
is_active=False
session, user_ids=[], is_active=False
)
assert count == 0
@pytest.mark.asyncio
async def test_bulk_update_status_reactivate(self, async_test_db):
"""Test bulk reactivating users."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create inactive user
async with AsyncTestingSessionLocal() as session:
@@ -440,7 +414,7 @@ class TestBulkUpdateStatus:
email="reactivate@example.com",
password="SecurePass123!",
first_name="Reactivate",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
# Deactivate
@@ -450,9 +424,7 @@ class TestBulkUpdateStatus:
# Reactivate
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status(
session,
user_ids=[user_id],
is_active=True
session, user_ids=[user_id], is_active=True
)
assert count == 1
@@ -468,7 +440,7 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio
async def test_bulk_soft_delete_success(self, async_test_db):
"""Test bulk soft deleting users."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
@@ -478,17 +450,14 @@ class TestBulkSoftDelete:
email=f"delete{i}@example.com",
password="SecurePass123!",
first_name=f"Delete{i}",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
# Bulk delete
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=user_ids
)
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
assert count == 3
# Verify all are soft deleted
@@ -501,7 +470,7 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
"""Test bulk soft delete with excluded user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users
user_ids = []
@@ -511,7 +480,7 @@ class TestBulkSoftDelete:
email=f"exclude{i}@example.com",
password="SecurePass123!",
first_name=f"Exclude{i}",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_ids.append(user.id)
@@ -520,9 +489,7 @@ class TestBulkSoftDelete:
exclude_id = user_ids[0]
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=user_ids,
exclude_user_id=exclude_id
session, user_ids=user_ids, exclude_user_id=exclude_id
)
assert count == 2 # Only 2 deleted
@@ -534,19 +501,16 @@ class TestBulkSoftDelete:
@pytest.mark.asyncio
async def test_bulk_soft_delete_empty_list(self, async_test_db):
"""Test bulk delete with empty list returns 0."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=[]
)
count = await user_crud.bulk_soft_delete(session, user_ids=[])
assert count == 0
@pytest.mark.asyncio
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
"""Test bulk delete where all users are excluded."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create user
async with AsyncTestingSessionLocal() as session:
@@ -554,7 +518,7 @@ class TestBulkSoftDelete:
email="onlyuser@example.com",
password="SecurePass123!",
first_name="Only",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -562,16 +526,14 @@ class TestBulkSoftDelete:
# Try to delete but exclude
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=[user_id],
exclude_user_id=user_id
session, user_ids=[user_id], exclude_user_id=user_id
)
assert count == 0
@pytest.mark.asyncio
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
"""Test bulk delete doesn't re-delete already deleted users."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create and delete user
async with AsyncTestingSessionLocal() as session:
@@ -579,7 +541,7 @@ class TestBulkSoftDelete:
email="predeleted@example.com",
password="SecurePass123!",
first_name="PreDeleted",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
user_id = user.id
@@ -589,10 +551,7 @@ class TestBulkSoftDelete:
# Try to delete again
async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(
session,
user_ids=[user_id]
)
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
assert count == 0 # Already deleted
@@ -602,7 +561,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_active_true(self, async_test_db, async_test_user):
"""Test is_active returns True for active user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -611,14 +570,14 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_active_false(self, async_test_db):
"""Test is_active returns False for inactive user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user_data = UserCreate(
email="inactive2@example.com",
password="SecurePass123!",
first_name="Inactive",
last_name="User"
last_name="User",
)
user = await user_crud.create(session, obj_in=user_data)
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
@@ -628,7 +587,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
"""Test is_superuser returns True for superuser."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_superuser.id))
@@ -637,7 +596,7 @@ class TestUtilityMethods:
@pytest.mark.asyncio
async def test_is_superuser_false(self, async_test_db, async_test_user):
"""Test is_superuser returns False for regular user_crud."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id))
@@ -654,42 +613,52 @@ class TestUserExceptionHandlers:
async def test_get_by_email_database_error(self, async_test_db):
"""Test get_by_email handles database errors (covers lines 30-32)."""
from unittest.mock import patch
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with patch.object(session, 'execute', side_effect=Exception("Database query failed")):
with patch.object(
session, "execute", side_effect=Exception("Database query failed")
):
with pytest.raises(Exception, match="Database query failed"):
await user_crud.get_by_email(session, email="test@example.com")
@pytest.mark.asyncio
async def test_bulk_update_status_database_error(self, async_test_db, async_test_user):
async def test_bulk_update_status_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_update_status handles database errors (covers lines 205-208)."""
from unittest.mock import patch, AsyncMock
test_engine, AsyncTestingSessionLocal = async_test_db
from unittest.mock import AsyncMock, patch
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock execute to fail
with patch.object(session, 'execute', side_effect=Exception("Bulk update failed")):
with patch.object(session, 'rollback', new_callable=AsyncMock):
with patch.object(
session, "execute", side_effect=Exception("Bulk update failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk update failed"):
await user_crud.bulk_update_status(
session,
user_ids=[async_test_user.id],
is_active=False
session, user_ids=[async_test_user.id], is_active=False
)
@pytest.mark.asyncio
async def test_bulk_soft_delete_database_error(self, async_test_db, async_test_user):
async def test_bulk_soft_delete_database_error(
self, async_test_db, async_test_user
):
"""Test bulk_soft_delete handles database errors (covers lines 257-260)."""
from unittest.mock import patch, AsyncMock
test_engine, AsyncTestingSessionLocal = async_test_db
from unittest.mock import AsyncMock, patch
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# Mock execute to fail
with patch.object(session, 'execute', side_effect=Exception("Bulk delete failed")):
with patch.object(session, 'rollback', new_callable=AsyncMock):
with patch.object(
session, "execute", side_effect=Exception("Bulk delete failed")
):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk delete failed"):
await user_crud.bulk_soft_delete(
session,
user_ids=[async_test_user.id]
session, user_ids=[async_test_user.id]
)

View File

@@ -1,8 +1,10 @@
# tests/models/test_user.py
import uuid
import pytest
from datetime import datetime
import pytest
from sqlalchemy.exc import IntegrityError
from app.models.user import User
@@ -166,7 +168,6 @@ def test_user_required_fields(db_session):
db_session.rollback()
def test_user_defaults(db_session):
"""Test that default values are correctly set."""
# Arrange - Create a minimal user with only required fields
@@ -210,22 +211,13 @@ def test_user_with_complex_json_preferences(db_session):
"""Test storing and retrieving complex JSON preferences."""
# Arrange - Create a user with nested JSON preferences
complex_preferences = {
"theme": {
"mode": "dark",
"colors": {
"primary": "#333",
"secondary": "#666"
}
},
"theme": {"mode": "dark", "colors": {"primary": "#333", "secondary": "#666"}},
"notifications": {
"email": True,
"sms": False,
"push": {
"enabled": True,
"quiet_hours": [22, 7]
}
"push": {"enabled": True, "quiet_hours": [22, 7]},
},
"tags": ["important", "family", "events"]
"tags": ["important", "family", "events"],
}
user = User(
@@ -234,16 +226,18 @@ def test_user_with_complex_json_preferences(db_session):
password_hash="hashedpassword",
first_name="Complex",
last_name="JSON",
preferences=complex_preferences
preferences=complex_preferences,
)
db_session.add(user)
db_session.commit()
# Act - Retrieve the user
retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first()
retrieved_user = (
db_session.query(User).filter_by(email="complex@example.com").first()
)
# Assert - The complex JSON should be preserved
assert retrieved_user.preferences == complex_preferences
assert retrieved_user.preferences["theme"]["colors"]["primary"] == "#333"
assert retrieved_user.preferences["notifications"]["push"]["quiet_hours"] == [22, 7]
assert "important" in retrieved_user.preferences["tags"]
assert "important" in retrieved_user.preferences["tags"]

View File

@@ -5,6 +5,7 @@ Covers Pydantic validators for:
- Slug validation (lines 26, 28, 30, 32, 62-70)
- Name validation (lines 40, 77)
"""
import pytest
from pydantic import ValidationError
@@ -20,19 +21,13 @@ class TestOrganizationBaseValidators:
def test_valid_organization_base(self):
"""Test that valid data passes validation."""
org = OrganizationBase(
name="Test Organization",
slug="test-org"
)
org = OrganizationBase(name="Test Organization", slug="test-org")
assert org.name == "Test Organization"
assert org.slug == "test-org"
def test_slug_none_returns_none(self):
"""Test that None slug is allowed (covers line 26)."""
org = OrganizationBase(
name="Test Organization",
slug=None
)
org = OrganizationBase(name="Test Organization", slug=None)
assert org.slug is None
def test_slug_invalid_characters_rejected(self):
@@ -40,57 +35,46 @@ class TestOrganizationBaseValidators:
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name="Test Organization",
slug="Test_Org!" # Uppercase and special chars
slug="Test_Org!", # Uppercase and special chars
)
errors = exc_info.value.errors()
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
assert any(
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
)
def test_slug_starts_with_hyphen_rejected(self):
"""Test slug starting with hyphen is rejected (covers line 30)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name="Test Organization",
slug="-test-org"
)
OrganizationBase(name="Test Organization", slug="-test-org")
errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_slug_ends_with_hyphen_rejected(self):
"""Test slug ending with hyphen is rejected (covers line 30)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name="Test Organization",
slug="test-org-"
)
OrganizationBase(name="Test Organization", slug="test-org-")
errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_slug_consecutive_hyphens_rejected(self):
"""Test slug with consecutive hyphens is rejected (covers line 32)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name="Test Organization",
slug="test--org"
)
OrganizationBase(name="Test Organization", slug="test--org")
errors = exc_info.value.errors()
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors)
assert any(
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
)
def test_name_whitespace_only_rejected(self):
"""Test whitespace-only name is rejected (covers line 40)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationBase(
name=" ",
slug="test-org"
)
OrganizationBase(name=" ", slug="test-org")
errors = exc_info.value.errors()
assert any("name cannot be empty" in str(e['msg']) for e in errors)
assert any("name cannot be empty" in str(e["msg"]) for e in errors)
def test_name_trimmed(self):
"""Test that name is trimmed."""
org = OrganizationBase(
name=" Test Organization ",
slug="test-org"
)
org = OrganizationBase(name=" Test Organization ", slug="test-org")
assert org.name == "Test Organization"
@@ -99,22 +83,18 @@ class TestOrganizationCreateValidators:
def test_valid_organization_create(self):
"""Test that valid data passes validation."""
org = OrganizationCreate(
name="Test Organization",
slug="test-org"
)
org = OrganizationCreate(name="Test Organization", slug="test-org")
assert org.name == "Test Organization"
assert org.slug == "test-org"
def test_slug_validation_inherited(self):
"""Test that slug validation is inherited from base."""
with pytest.raises(ValidationError) as exc_info:
OrganizationCreate(
name="Test",
slug="Invalid_Slug!"
)
OrganizationCreate(name="Test", slug="Invalid_Slug!")
errors = exc_info.value.errors()
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
assert any(
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
)
class TestOrganizationUpdateValidators:
@@ -122,10 +102,7 @@ class TestOrganizationUpdateValidators:
def test_valid_organization_update(self):
"""Test that valid update data passes validation."""
org = OrganizationUpdate(
name="Updated Name",
slug="updated-slug"
)
org = OrganizationUpdate(name="Updated Name", slug="updated-slug")
assert org.name == "Updated Name"
assert org.slug == "updated-slug"
@@ -139,35 +116,39 @@ class TestOrganizationUpdateValidators:
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="Test_Org!")
errors = exc_info.value.errors()
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
assert any(
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
)
def test_update_slug_starts_with_hyphen_rejected(self):
"""Test update slug starting with hyphen is rejected (covers line 66)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="-test-org")
errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_update_slug_ends_with_hyphen_rejected(self):
"""Test update slug ending with hyphen is rejected (covers line 66)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="test-org-")
errors = exc_info.value.errors()
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
def test_update_slug_consecutive_hyphens_rejected(self):
"""Test update slug with consecutive hyphens is rejected (covers line 68)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(slug="test--org")
errors = exc_info.value.errors()
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors)
assert any(
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
)
def test_update_name_whitespace_only_rejected(self):
"""Test whitespace-only name in update is rejected (covers line 77)."""
with pytest.raises(ValidationError) as exc_info:
OrganizationUpdate(name=" ")
errors = exc_info.value.errors()
assert any("name cannot be empty" in str(e['msg']) for e in errors)
assert any("name cannot be empty" in str(e["msg"]) for e in errors)
def test_update_name_none_allowed(self):
"""Test that None name is allowed in update."""

View File

@@ -1,80 +1,177 @@
# tests/schemas/test_user_schemas.py
import pytest
import re
import pytest
from pydantic import ValidationError
from app.schemas.users import UserBase, UserCreate
class TestPhoneNumberValidation:
"""Tests for phone number validation in user schemas"""
def test_valid_swiss_numbers(self):
"""Test valid Swiss phone numbers are accepted"""
# International format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41791234567",
)
assert user.phone_number == "+41791234567"
# Local format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0791234567",
)
assert user.phone_number == "0791234567"
# With formatting characters
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41 79 123 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="079 123 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41-79-123-45-67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="079-123-45-67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+41 (79) 123 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="079 (123) 45 67",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
def test_valid_italian_numbers(self):
"""Test valid Italian phone numbers are accepted"""
# International format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+393451234567",
)
assert user.phone_number == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39345123456",
)
assert user.phone_number == "+39345123456"
# Local format
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="03451234567",
)
assert user.phone_number == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789")
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345123456789",
)
assert user.phone_number == "0345123456789"
# With formatting characters
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39 345 123 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345 123 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39-345-123-4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345-123-4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="+39 (345) 123 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567")
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number="0345 (123) 4567",
)
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
def test_none_phone_number(self):
"""Test that None is accepted as a valid value (optional phone number)"""
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None)
user = UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number=None,
)
assert user.phone_number is None
def test_invalid_phone_numbers(self):
@@ -83,17 +180,14 @@ class TestPhoneNumberValidation:
# Too short
"+12",
"012",
# Invalid characters
"+41xyz123456",
"079abc4567",
"123-abc-7890",
"+1(800)CALL-NOW",
# Completely invalid formats
"++4412345678", # Double plus
# Note: "()+41123456" becomes "+41123456" after cleaning, which is valid
# Empty string
"",
# Spaces only
@@ -102,7 +196,12 @@ class TestPhoneNumberValidation:
for number in invalid_numbers:
with pytest.raises(ValidationError):
UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number)
UserBase(
email="test@example.com",
first_name="Test",
last_name="User",
phone_number=number,
)
def test_phone_validation_in_user_create(self):
"""Test that phone validation also works in UserCreate schema"""
@@ -112,7 +211,7 @@ class TestPhoneNumberValidation:
first_name="Test",
last_name="User",
password="Password123!",
phone_number="+41791234567"
phone_number="+41791234567",
)
assert user.phone_number == "+41791234567"
@@ -123,5 +222,5 @@ class TestPhoneNumberValidation:
first_name="Test",
last_name="User",
password="Password123!",
phone_number="invalid-number"
)
phone_number="invalid-number",
)

View File

@@ -7,12 +7,13 @@ Covers all edge cases in validation functions:
- validate_email_format (line 148)
- validate_slug (lines 170-183)
"""
import pytest
from app.schemas.validators import (
validate_email_format,
validate_password_strength,
validate_phone_number,
validate_email_format,
validate_slug,
)
@@ -108,12 +109,14 @@ class TestPhoneNumberValidator:
validate_phone_number("+123456789012345") # 15 digits after +
def test_multiple_plus_symbols_rejected(self):
"""Test phone number with multiple + symbols.
r"""Test phone number with multiple + symbols.
Note: Line 115 is defensive code - the regex check at line 110 catches this first.
The regex ^(?:\+[0-9]{8,14}|0[0-9]{8,14})$ only allows + at the start.
"""
with pytest.raises(ValueError, match="must start with \\+ or 0 followed by 8-14 digits"):
with pytest.raises(
ValueError, match="must start with \\+ or 0 followed by 8-14 digits"
):
validate_phone_number("+1234+5678901")
def test_non_digit_after_prefix_rejected(self):

View File

@@ -1,14 +1,18 @@
# tests/services/test_auth_service.py
import uuid
import pytest
import pytest_asyncio
from unittest.mock import patch
import pytest
from sqlalchemy import select
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
from app.core.auth import (
TokenInvalidError,
get_password_hash,
verify_password,
)
from app.models.user import User
from app.schemas.users import UserCreate, Token
from app.services.auth_service import AuthService, AuthenticationError
from app.schemas.users import Token, UserCreate
from app.services.auth_service import AuthenticationError, AuthService
class TestAuthServiceAuthentication:
@@ -17,12 +21,14 @@ class TestAuthServiceAuthentication:
@pytest.mark.asyncio
async def test_authenticate_valid_user(self, async_test_db, async_test_user):
"""Test authenticating a user with valid credentials"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(password)
await session.commit()
@@ -30,9 +36,7 @@ class TestAuthServiceAuthentication:
# Authenticate with correct credentials
async with AsyncTestingSessionLocal() as session:
auth_user = await AuthService.authenticate_user(
db=session,
email=async_test_user.email,
password=password
db=session, email=async_test_user.email, password=password
)
assert auth_user is not None
@@ -42,26 +46,28 @@ class TestAuthServiceAuthentication:
@pytest.mark.asyncio
async def test_authenticate_nonexistent_user(self, async_test_db):
"""Test authenticating with an email that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await AuthService.authenticate_user(
db=session,
email="nonexistent@example.com",
password="password"
db=session, email="nonexistent@example.com", password="password"
)
assert user is None
@pytest.mark.asyncio
async def test_authenticate_with_wrong_password(self, async_test_db, async_test_user):
async def test_authenticate_with_wrong_password(
self, async_test_db, async_test_user
):
"""Test authenticating with the wrong password"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(password)
await session.commit()
@@ -69,9 +75,7 @@ class TestAuthServiceAuthentication:
# Authenticate with wrong password
async with AsyncTestingSessionLocal() as session:
auth_user = await AuthService.authenticate_user(
db=session,
email=async_test_user.email,
password="WrongPassword123"
db=session, email=async_test_user.email, password="WrongPassword123"
)
assert auth_user is None
@@ -79,12 +83,14 @@ class TestAuthServiceAuthentication:
@pytest.mark.asyncio
async def test_authenticate_inactive_user(self, async_test_db, async_test_user):
"""Test authenticating an inactive user"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password and make user inactive
password = "TestPassword123!"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(password)
user.is_active = False
@@ -94,9 +100,7 @@ class TestAuthServiceAuthentication:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(AuthenticationError):
await AuthService.authenticate_user(
db=session,
email=async_test_user.email,
password=password
db=session, email=async_test_user.email, password=password
)
@@ -106,14 +110,14 @@ class TestAuthServiceUserCreation:
@pytest.mark.asyncio
async def test_create_new_user(self, async_test_db):
"""Test creating a new user"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
user_data = UserCreate(
email="newuser@example.com",
password="TestPassword123!",
first_name="New",
last_name="User",
phone_number="+1234567890"
phone_number="+1234567890",
)
async with AsyncTestingSessionLocal() as session:
@@ -135,15 +139,17 @@ class TestAuthServiceUserCreation:
assert user.is_superuser is False
@pytest.mark.asyncio
async def test_create_user_with_existing_email(self, async_test_db, async_test_user):
async def test_create_user_with_existing_email(
self, async_test_db, async_test_user
):
"""Test creating a user with an email that already exists"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
user_data = UserCreate(
email=async_test_user.email, # Use existing email
password="TestPassword123!",
first_name="Duplicate",
last_name="User"
last_name="User",
)
# Should raise AuthenticationError
@@ -169,7 +175,7 @@ class TestAuthServiceTokens:
@pytest.mark.asyncio
async def test_refresh_tokens(self, async_test_db, async_test_user):
"""Test refreshing tokens with a valid refresh token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create initial tokens
initial_tokens = AuthService.create_tokens(async_test_user)
@@ -177,8 +183,7 @@ class TestAuthServiceTokens:
# Refresh tokens
async with AsyncTestingSessionLocal() as session:
new_tokens = await AuthService.refresh_tokens(
db=session,
refresh_token=initial_tokens.refresh_token
db=session, refresh_token=initial_tokens.refresh_token
)
# Verify new tokens are different from old ones
@@ -188,7 +193,7 @@ class TestAuthServiceTokens:
@pytest.mark.asyncio
async def test_refresh_tokens_with_invalid_token(self, async_test_db):
"""Test refreshing tokens with an invalid token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create an invalid token
invalid_token = "invalid.token.string"
@@ -197,14 +202,15 @@ class TestAuthServiceTokens:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(TokenInvalidError):
await AuthService.refresh_tokens(
db=session,
refresh_token=invalid_token
db=session, refresh_token=invalid_token
)
@pytest.mark.asyncio
async def test_refresh_tokens_with_access_token(self, async_test_db, async_test_user):
async def test_refresh_tokens_with_access_token(
self, async_test_db, async_test_user
):
"""Test refreshing tokens with an access token instead of refresh token"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create tokens
tokens = AuthService.create_tokens(async_test_user)
@@ -213,18 +219,20 @@ class TestAuthServiceTokens:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(TokenInvalidError):
await AuthService.refresh_tokens(
db=session,
refresh_token=tokens.access_token
db=session, refresh_token=tokens.access_token
)
@pytest.mark.asyncio
async def test_refresh_tokens_with_nonexistent_user(self, async_test_db):
"""Test refreshing tokens for a user that doesn't exist in the database"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a token for a non-existent user
non_existent_id = str(uuid.uuid4())
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
with (
patch("app.core.auth.decode_token"),
patch("app.core.auth.get_token_data") as mock_get_data,
):
# Mock the token data to return a non-existent user ID
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
@@ -232,8 +240,7 @@ class TestAuthServiceTokens:
async with AsyncTestingSessionLocal() as session:
with pytest.raises(TokenInvalidError):
await AuthService.refresh_tokens(
db=session,
refresh_token="some.refresh.token"
db=session, refresh_token="some.refresh.token"
)
@@ -243,12 +250,14 @@ class TestAuthServicePasswordChange:
@pytest.mark.asyncio
async def test_change_password(self, async_test_db, async_test_user):
"""Test changing a user's password"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
current_password = "CurrentPassword123"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(current_password)
await session.commit()
@@ -260,7 +269,7 @@ class TestAuthServicePasswordChange:
db=session,
user_id=async_test_user.id,
current_password=current_password,
new_password=new_password
new_password=new_password,
)
# Verify operation was successful
@@ -268,7 +277,9 @@ class TestAuthServicePasswordChange:
# Verify password was changed
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
updated_user = result.scalar_one_or_none()
# Verify old password no longer works
@@ -278,14 +289,18 @@ class TestAuthServicePasswordChange:
assert verify_password(new_password, updated_user.password_hash)
@pytest.mark.asyncio
async def test_change_password_wrong_current_password(self, async_test_db, async_test_user):
async def test_change_password_wrong_current_password(
self, async_test_db, async_test_user
):
"""Test changing password with incorrect current password"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Set a known password for the mock user
current_password = "CurrentPassword123"
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
user.password_hash = get_password_hash(current_password)
await session.commit()
@@ -298,19 +313,21 @@ class TestAuthServicePasswordChange:
db=session,
user_id=async_test_user.id,
current_password=wrong_password,
new_password="NewPassword456"
new_password="NewPassword456",
)
# Verify password was not changed
async with AsyncTestingSessionLocal() as session:
result = await session.execute(select(User).where(User.id == async_test_user.id))
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
assert verify_password(current_password, user.password_hash)
@pytest.mark.asyncio
async def test_change_password_nonexistent_user(self, async_test_db):
"""Test changing password for a user that doesn't exist"""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
non_existent_id = uuid.uuid4()
@@ -320,5 +337,5 @@ class TestAuthServicePasswordChange:
db=session,
user_id=non_existent_id,
current_password="CurrentPassword123",
new_password="NewPassword456"
new_password="NewPassword456",
)

View File

@@ -2,13 +2,15 @@
"""
Tests for email service functionality.
"""
from unittest.mock import AsyncMock
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from app.services.email_service import (
EmailService,
ConsoleEmailBackend,
SMTPEmailBackend
EmailService,
SMTPEmailBackend,
)
@@ -24,7 +26,7 @@ class TestConsoleEmailBackend:
to=["user@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>",
text_content="Test Text"
text_content="Test Text",
)
assert result is True
@@ -37,7 +39,7 @@ class TestConsoleEmailBackend:
result = await backend.send_email(
to=["user@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>"
html_content="<p>Test HTML</p>",
)
assert result is True
@@ -50,7 +52,7 @@ class TestConsoleEmailBackend:
result = await backend.send_email(
to=["user1@example.com", "user2@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>"
html_content="<p>Test HTML</p>",
)
assert result is True
@@ -66,7 +68,7 @@ class TestSMTPEmailBackend:
host="smtp.example.com",
port=587,
username="test@example.com",
password="password"
password="password",
)
assert backend.host == "smtp.example.com"
@@ -81,14 +83,14 @@ class TestSMTPEmailBackend:
host="smtp.example.com",
port=587,
username="test@example.com",
password="password"
password="password",
)
# Should fall back to console backend since SMTP is not implemented
result = await backend.send_email(
to=["user@example.com"],
subject="Test Subject",
html_content="<p>Test HTML</p>"
html_content="<p>Test HTML</p>",
)
assert result is True
@@ -114,9 +116,7 @@ class TestEmailService:
service = EmailService()
result = await service.send_password_reset_email(
to_email="user@example.com",
reset_token="test_token_123",
user_name="John"
to_email="user@example.com", reset_token="test_token_123", user_name="John"
)
assert result is True
@@ -127,8 +127,7 @@ class TestEmailService:
service = EmailService()
result = await service.send_password_reset_email(
to_email="user@example.com",
reset_token="test_token_123"
to_email="user@example.com", reset_token="test_token_123"
)
assert result is True
@@ -142,8 +141,7 @@ class TestEmailService:
token = "test_reset_token_xyz"
await service.send_password_reset_email(
to_email="user@example.com",
reset_token=token
to_email="user@example.com", reset_token=token
)
# Verify send_email was called
@@ -151,7 +149,7 @@ class TestEmailService:
call_args = backend_mock.send_email.call_args
# Check that token is in the HTML content
html_content = call_args.kwargs['html_content']
html_content = call_args.kwargs["html_content"]
assert token in html_content
@pytest.mark.asyncio
@@ -162,8 +160,7 @@ class TestEmailService:
service = EmailService(backend=backend_mock)
result = await service.send_password_reset_email(
to_email="user@example.com",
reset_token="test_token"
to_email="user@example.com", reset_token="test_token"
)
assert result is False
@@ -176,7 +173,7 @@ class TestEmailService:
result = await service.send_email_verification(
to_email="user@example.com",
verification_token="verification_token_123",
user_name="Jane"
user_name="Jane",
)
assert result is True
@@ -187,8 +184,7 @@ class TestEmailService:
service = EmailService()
result = await service.send_email_verification(
to_email="user@example.com",
verification_token="verification_token_123"
to_email="user@example.com", verification_token="verification_token_123"
)
assert result is True
@@ -202,8 +198,7 @@ class TestEmailService:
token = "test_verification_token_xyz"
await service.send_email_verification(
to_email="user@example.com",
verification_token=token
to_email="user@example.com", verification_token=token
)
# Verify send_email was called
@@ -211,7 +206,7 @@ class TestEmailService:
call_args = backend_mock.send_email.call_args
# Check that token is in the HTML content
html_content = call_args.kwargs['html_content']
html_content = call_args.kwargs["html_content"]
assert token in html_content
@pytest.mark.asyncio
@@ -222,8 +217,7 @@ class TestEmailService:
service = EmailService(backend=backend_mock)
result = await service.send_email_verification(
to_email="user@example.com",
verification_token="test_token"
to_email="user@example.com", verification_token="test_token"
)
assert result is False
@@ -236,14 +230,12 @@ class TestEmailService:
service = EmailService(backend=backend_mock)
await service.send_password_reset_email(
to_email="user@example.com",
reset_token="token123",
user_name="Test User"
to_email="user@example.com", reset_token="token123", user_name="Test User"
)
call_args = backend_mock.send_email.call_args
html_content = call_args.kwargs['html_content']
text_content = call_args.kwargs['text_content']
html_content = call_args.kwargs["html_content"]
text_content = call_args.kwargs["text_content"]
# Check HTML content
assert "Password Reset" in html_content
@@ -251,7 +243,9 @@ class TestEmailService:
assert "Test User" in html_content
# Check text content
assert "Password Reset" in text_content or "password reset" in text_content.lower()
assert (
"Password Reset" in text_content or "password reset" in text_content.lower()
)
assert "token123" in text_content
@pytest.mark.asyncio
@@ -264,12 +258,12 @@ class TestEmailService:
await service.send_email_verification(
to_email="user@example.com",
verification_token="verify123",
user_name="Test User"
user_name="Test User",
)
call_args = backend_mock.send_email.call_args
html_content = call_args.kwargs['html_content']
text_content = call_args.kwargs['text_content']
html_content = call_args.kwargs["html_content"]
text_content = call_args.kwargs["text_content"]
# Check HTML content
assert "Verify" in html_content

View File

@@ -2,23 +2,27 @@
"""
Comprehensive tests for session cleanup service.
"""
import pytest
import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import patch, MagicMock, AsyncMock
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import select
from app.models.user_session import UserSession
from sqlalchemy import select
class TestCleanupExpiredSessions:
"""Tests for cleanup_expired_sessions function."""
@pytest.mark.asyncio
async def test_cleanup_expired_sessions_success(self, async_test_db, async_test_user):
async def test_cleanup_expired_sessions_success(
self, async_test_db, async_test_user
):
"""Test successful cleanup of expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create mix of sessions
async with AsyncTestingSessionLocal() as session:
@@ -30,9 +34,9 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
created_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
created_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC),
)
# 2. Inactive, expired, old (SHOULD be deleted)
@@ -43,9 +47,9 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.2",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
created_at=datetime.now(timezone.utc) - timedelta(days=40),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=10),
created_at=datetime.now(UTC) - timedelta(days=40),
last_used_at=datetime.now(UTC),
)
# 3. Inactive, expired, recent (should NOT be deleted - within keep_days)
@@ -56,17 +60,23 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.3",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
created_at=datetime.now(timezone.utc) - timedelta(days=5),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=1),
created_at=datetime.now(UTC) - timedelta(days=5),
last_used_at=datetime.now(UTC),
)
session.add_all([active_session, old_expired_session, recent_expired_session])
session.add_all(
[active_session, old_expired_session, recent_expired_session]
)
await session.commit()
# Mock SessionLocal to return our test session
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30)
# Should only delete old_expired_session
@@ -85,7 +95,7 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user):
"""Test cleanup when no sessions meet deletion criteria."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
active = UserSession(
@@ -95,15 +105,19 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
created_at=datetime.now(timezone.utc),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
created_at=datetime.now(UTC),
last_used_at=datetime.now(UTC),
)
session.add(active)
await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30)
assert deleted_count == 0
@@ -111,10 +125,14 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_empty_database(self, async_test_db):
"""Test cleanup with no sessions in database."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30)
assert deleted_count == 0
@@ -122,7 +140,7 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user):
"""Test cleanup with keep_days=0 deletes all inactive expired sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
today_expired = UserSession(
@@ -132,15 +150,19 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
created_at=datetime.now(timezone.utc) - timedelta(hours=2),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(hours=1),
created_at=datetime.now(UTC) - timedelta(hours=2),
last_used_at=datetime.now(UTC),
)
session.add(today_expired)
await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=0)
assert deleted_count == 1
@@ -148,7 +170,7 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user):
"""Test that cleanup uses bulk DELETE for many sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create 50 expired sessions
async with AsyncTestingSessionLocal() as session:
@@ -161,16 +183,20 @@ class TestCleanupExpiredSessions:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
created_at=datetime.now(timezone.utc) - timedelta(days=40),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=10),
created_at=datetime.now(UTC) - timedelta(days=40),
last_used_at=datetime.now(UTC),
)
sessions_to_add.append(expired)
session.add_all(sessions_to_add)
await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
deleted_count = await cleanup_expired_sessions(keep_days=30)
assert deleted_count == 50
@@ -178,14 +204,20 @@ class TestCleanupExpiredSessions:
@pytest.mark.asyncio
async def test_cleanup_database_error_returns_zero(self, async_test_db):
"""Test cleanup returns 0 on database errors (doesn't crash)."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Mock session_crud.cleanup_expired to raise error
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup:
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
with patch(
"app.services.session_cleanup.session_crud.cleanup_expired"
) as mock_cleanup:
mock_cleanup.side_effect = Exception("Database connection lost")
from app.services.session_cleanup import cleanup_expired_sessions
# Should not crash, should return 0
deleted_count = await cleanup_expired_sessions(keep_days=30)
@@ -198,7 +230,7 @@ class TestGetSessionStatistics:
@pytest.mark.asyncio
async def test_get_statistics_with_sessions(self, async_test_db, async_test_user):
"""Test getting session statistics with various session types."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
# 2 active, not expired
@@ -210,9 +242,9 @@ class TestGetSessionStatistics:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
created_at=datetime.now(timezone.utc),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) + timedelta(days=7),
created_at=datetime.now(UTC),
last_used_at=datetime.now(UTC),
)
session.add(active)
@@ -225,9 +257,9 @@ class TestGetSessionStatistics:
ip_address="192.168.1.2",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
created_at=datetime.now(timezone.utc) - timedelta(days=2),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=1),
created_at=datetime.now(UTC) - timedelta(days=2),
last_used_at=datetime.now(UTC),
)
session.add(inactive)
@@ -239,16 +271,20 @@ class TestGetSessionStatistics:
ip_address="192.168.1.3",
user_agent="Mozilla/5.0",
is_active=True,
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
created_at=datetime.now(timezone.utc) - timedelta(days=1),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(hours=1),
created_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC),
)
session.add(expired_active)
await session.commit()
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import get_session_statistics
stats = await get_session_statistics()
assert stats["total"] == 6
@@ -259,10 +295,14 @@ class TestGetSessionStatistics:
@pytest.mark.asyncio
async def test_get_statistics_empty_database(self, async_test_db):
"""Test getting statistics with no sessions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import get_session_statistics
stats = await get_session_statistics()
assert stats["total"] == 0
@@ -271,9 +311,11 @@ class TestGetSessionStatistics:
assert stats["expired"] == 0
@pytest.mark.asyncio
async def test_get_statistics_database_error_returns_empty_dict(self, async_test_db):
async def test_get_statistics_database_error_returns_empty_dict(
self, async_test_db
):
"""Test statistics returns empty dict on database errors."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, _AsyncTestingSessionLocal = async_test_db
# Create a mock that raises on execute
mock_session = AsyncMock()
@@ -283,8 +325,12 @@ class TestGetSessionStatistics:
async def mock_session_local():
yield mock_session
with patch('app.services.session_cleanup.SessionLocal', return_value=mock_session_local()):
with patch(
"app.services.session_cleanup.SessionLocal",
return_value=mock_session_local(),
):
from app.services.session_cleanup import get_session_statistics
stats = await get_session_statistics()
assert stats == {}
@@ -294,9 +340,11 @@ class TestConcurrentCleanup:
"""Tests for concurrent cleanup scenarios."""
@pytest.mark.asyncio
async def test_concurrent_cleanup_no_duplicate_deletes(self, async_test_db, async_test_user):
async def test_concurrent_cleanup_no_duplicate_deletes(
self, async_test_db, async_test_user
):
"""Test concurrent cleanups don't cause race conditions."""
test_engine, AsyncTestingSessionLocal = async_test_db
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create 10 expired sessions
async with AsyncTestingSessionLocal() as session:
@@ -308,20 +356,24 @@ class TestConcurrentCleanup:
ip_address="192.168.1.1",
user_agent="Mozilla/5.0",
is_active=False,
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
created_at=datetime.now(timezone.utc) - timedelta(days=40),
last_used_at=datetime.now(timezone.utc)
expires_at=datetime.now(UTC) - timedelta(days=10),
created_at=datetime.now(UTC) - timedelta(days=40),
last_used_at=datetime.now(UTC),
)
session.add(expired)
await session.commit()
# Run two cleanups concurrently
# Use side_effect to return fresh session instances for each call
with patch('app.services.session_cleanup.SessionLocal', side_effect=lambda: AsyncTestingSessionLocal()):
with patch(
"app.services.session_cleanup.SessionLocal",
side_effect=lambda: AsyncTestingSessionLocal(),
):
from app.services.session_cleanup import cleanup_expired_sessions
results = await asyncio.gather(
cleanup_expired_sessions(keep_days=30),
cleanup_expired_sessions(keep_days=30)
cleanup_expired_sessions(keep_days=30),
)
# Both should report deleting sessions (may overlap due to transaction timing)

View File

@@ -2,12 +2,13 @@
"""
Tests for database initialization script.
"""
import pytest
import pytest_asyncio
from unittest.mock import AsyncMock, patch
from app.init_db import init_db
from unittest.mock import patch
import pytest
from app.core.config import settings
from app.init_db import init_db
class TestInitDb:
@@ -16,69 +17,86 @@ class TestInitDb:
@pytest.mark.asyncio
async def test_init_db_creates_superuser_when_not_exists(self, async_test_db):
"""Test that init_db creates a superuser when one doesn't exist."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Mock the SessionLocal to use our test database
with patch('app.init_db.SessionLocal', SessionLocal):
with patch("app.init_db.SessionLocal", SessionLocal):
# Mock settings to provide test credentials
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test_admin@example.com'):
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestAdmin123!'):
with patch.object(
settings, "FIRST_SUPERUSER_EMAIL", "test_admin@example.com"
):
with patch.object(
settings, "FIRST_SUPERUSER_PASSWORD", "TestAdmin123!"
):
# Run init_db
user = await init_db()
# Verify superuser was created
assert user is not None
assert user.email == 'test_admin@example.com'
assert user.email == "test_admin@example.com"
assert user.is_superuser is True
assert user.first_name == 'Admin'
assert user.last_name == 'User'
assert user.first_name == "Admin"
assert user.last_name == "User"
@pytest.mark.asyncio
async def test_init_db_returns_existing_superuser(self, async_test_db, async_test_user):
async def test_init_db_returns_existing_superuser(
self, async_test_db, async_test_user
):
"""Test that init_db returns existing superuser instead of creating duplicate."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Mock the SessionLocal to use our test database
with patch('app.init_db.SessionLocal', SessionLocal):
with patch("app.init_db.SessionLocal", SessionLocal):
# Mock settings to match async_test_user's email
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'testuser@example.com'):
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
with patch.object(
settings, "FIRST_SUPERUSER_EMAIL", "testuser@example.com"
):
with patch.object(
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
):
# Run init_db
user = await init_db()
# Verify it returns the existing user
assert user is not None
assert user.id == async_test_user.id
assert user.email == 'testuser@example.com'
assert user.email == "testuser@example.com"
@pytest.mark.asyncio
async def test_init_db_uses_default_credentials(self, async_test_db):
"""Test that init_db uses default credentials when env vars not set."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Mock the SessionLocal to use our test database
with patch('app.init_db.SessionLocal', SessionLocal):
with patch("app.init_db.SessionLocal", SessionLocal):
# Mock settings to have None values (not configured)
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', None):
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', None):
with patch.object(settings, "FIRST_SUPERUSER_EMAIL", None):
with patch.object(settings, "FIRST_SUPERUSER_PASSWORD", None):
# Run init_db
user = await init_db()
# Verify superuser was created with defaults
assert user is not None
assert user.email == 'admin@example.com'
assert user.email == "admin@example.com"
assert user.is_superuser is True
@pytest.mark.asyncio
async def test_init_db_handles_database_errors(self, async_test_db):
"""Test that init_db handles database errors gracefully."""
test_engine, SessionLocal = async_test_db
_test_engine, SessionLocal = async_test_db
# Mock user_crud.get_by_email to raise an exception
with patch('app.init_db.user_crud.get_by_email', side_effect=Exception("Database error")):
with patch('app.init_db.SessionLocal', SessionLocal):
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test@example.com'):
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
with patch(
"app.init_db.user_crud.get_by_email",
side_effect=Exception("Database error"),
):
with patch("app.init_db.SessionLocal", SessionLocal):
with patch.object(
settings, "FIRST_SUPERUSER_EMAIL", "test@example.com"
):
with patch.object(
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
):
# Run init_db and expect it to raise
with pytest.raises(Exception, match="Database error"):
await init_db()

View File

@@ -2,18 +2,18 @@
"""
Comprehensive tests for device utility functions.
"""
import pytest
from unittest.mock import Mock
from fastapi import Request
from app.utils.device import (
extract_device_info,
parse_device_name,
extract_browser,
extract_device_info,
get_client_ip,
get_device_type,
is_mobile_device,
get_device_type
parse_device_name,
)
@@ -138,7 +138,9 @@ class TestExtractBrowser:
def test_extract_browser_edge_legacy(self):
"""Test extracting legacy Edge browser."""
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
ua = (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
)
result = extract_browser(ua)
assert result == "Edge"
@@ -249,7 +251,7 @@ class TestGetClientIp:
request = Mock(spec=Request)
request.headers = {
"x-forwarded-for": "192.168.1.100",
"x-real-ip": "192.168.1.200"
"x-real-ip": "192.168.1.200",
}
request.client = Mock()
request.client.host = "192.168.1.50"
@@ -385,7 +387,7 @@ class TestExtractDeviceInfo:
request.headers = {
"user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)",
"x-device-id": "device-123-456",
"x-forwarded-for": "192.168.1.100"
"x-forwarded-for": "192.168.1.100",
}
request.client = None

View File

@@ -2,19 +2,21 @@
"""
Tests for security utility functions.
"""
import time
import base64
import json
import time
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import patch, MagicMock
from app.utils.security import (
create_upload_token,
verify_upload_token,
create_password_reset_token,
verify_password_reset_token,
create_email_verification_token,
verify_email_verification_token
create_password_reset_token,
create_upload_token,
verify_email_verification_token,
verify_password_reset_token,
verify_upload_token,
)
@@ -31,7 +33,7 @@ class TestCreateUploadToken:
# Token should be base64 encoded
try:
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
assert "payload" in token_data
assert "signature" in token_data
@@ -46,7 +48,7 @@ class TestCreateUploadToken:
token = create_upload_token(file_path, content_type)
# Decode and verify payload
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
payload = token_data["payload"]
@@ -62,7 +64,7 @@ class TestCreateUploadToken:
after = int(time.time())
# Decode token
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
payload = token_data["payload"]
@@ -74,11 +76,13 @@ class TestCreateUploadToken:
"""Test token creation with custom expiration time."""
custom_exp = 600 # 10 minutes
before = int(time.time())
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp)
token = create_upload_token(
"/uploads/test.jpg", "image/jpeg", expires_in=custom_exp
)
after = int(time.time())
# Decode token
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
payload = token_data["payload"]
@@ -92,11 +96,11 @@ class TestCreateUploadToken:
token2 = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode both tokens
decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8'))
decoded1 = base64.urlsafe_b64decode(token1.encode("utf-8"))
token_data1 = json.loads(decoded1)
nonce1 = token_data1["payload"]["nonce"]
decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8'))
decoded2 = base64.urlsafe_b64decode(token2.encode("utf-8"))
token_data2 = json.loads(decoded2)
nonce2 = token_data2["payload"]["nonce"]
@@ -133,7 +137,7 @@ class TestVerifyUploadToken:
current_time = 1000000
mock_time.time = MagicMock(return_value=current_time)
with patch('app.utils.security.time', mock_time):
with patch("app.utils.security.time", mock_time):
# Create token that "expires" at current_time + 1
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1)
@@ -149,13 +153,15 @@ class TestVerifyUploadToken:
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode, modify, and re-encode
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
token_data["signature"] = "invalid_signature"
# Re-encode the tampered token
tampered_json = json.dumps(token_data)
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
"utf-8"
)
payload = verify_upload_token(tampered_token)
assert payload is None
@@ -165,13 +171,15 @@ class TestVerifyUploadToken:
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
# Decode, modify payload, and re-encode
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
token_data = json.loads(decoded)
token_data["payload"]["path"] = "/uploads/hacked.exe"
# Re-encode the tampered token (signature won't match)
tampered_json = json.dumps(token_data)
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
"utf-8"
)
payload = verify_upload_token(tampered_token)
assert payload is None
@@ -194,7 +202,9 @@ class TestVerifyUploadToken:
"""Test that tokens with invalid JSON are rejected."""
# Create a base64 string that decodes to invalid JSON
invalid_json = "not valid json"
invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8')
invalid_token = base64.urlsafe_b64encode(invalid_json.encode("utf-8")).decode(
"utf-8"
)
payload = verify_upload_token(invalid_token)
assert payload is None
@@ -207,11 +217,13 @@ class TestVerifyUploadToken:
"path": "/uploads/test.jpg"
# Missing content_type, exp, nonce
},
"signature": "some_signature"
"signature": "some_signature",
}
incomplete_json = json.dumps(incomplete_data)
incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8')
incomplete_token = base64.urlsafe_b64encode(
incomplete_json.encode("utf-8")
).decode("utf-8")
payload = verify_upload_token(incomplete_token)
assert payload is None
@@ -266,7 +278,7 @@ class TestPasswordResetTokens:
email = "user@example.com"
# Create token that expires in 1 second
with patch('app.utils.security.time') as mock_time:
with patch("app.utils.security.time") as mock_time:
mock_time.time = MagicMock(return_value=1000000)
token = create_password_reset_token(email, expires_in=1)
@@ -287,12 +299,14 @@ class TestPasswordResetTokens:
token = create_password_reset_token(email)
# Decode and tamper
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com"
# Re-encode
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
tampered = base64.urlsafe_b64encode(
json.dumps(token_data).encode("utf-8")
).decode("utf-8")
verified_email = verify_password_reset_token(tampered)
assert verified_email is None
@@ -312,14 +326,14 @@ class TestPasswordResetTokens:
email = "user@example.com"
custom_exp = 7200 # 2 hours
with patch('app.utils.security.time') as mock_time:
with patch("app.utils.security.time") as mock_time:
current_time = 1000000
mock_time.time = MagicMock(return_value=current_time)
token = create_password_reset_token(email, expires_in=custom_exp)
# Decode to check expiration
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
assert token_data["payload"]["exp"] == current_time + custom_exp
@@ -350,7 +364,7 @@ class TestEmailVerificationTokens:
"""Test that expired verification tokens are rejected."""
email = "user@example.com"
with patch('app.utils.security.time') as mock_time:
with patch("app.utils.security.time") as mock_time:
mock_time.time = MagicMock(return_value=1000000)
token = create_email_verification_token(email, expires_in=1)
@@ -371,12 +385,14 @@ class TestEmailVerificationTokens:
token = create_email_verification_token(email)
# Decode and tamper
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
token_data["payload"]["email"] = "hacker@example.com"
# Re-encode
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
tampered = base64.urlsafe_b64encode(
json.dumps(token_data).encode("utf-8")
).decode("utf-8")
verified_email = verify_email_verification_token(tampered)
assert verified_email is None
@@ -395,14 +411,14 @@ class TestEmailVerificationTokens:
"""Test email verification token with default 24-hour expiration."""
email = "user@example.com"
with patch('app.utils.security.time') as mock_time:
with patch("app.utils.security.time") as mock_time:
current_time = 1000000
mock_time.time = MagicMock(return_value=current_time)
token = create_email_verification_token(email)
# Decode to check expiration (should be 86400 seconds = 24 hours)
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(decoded)
assert token_data["payload"]["exp"] == current_time + 86400