Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff
- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest). - Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting. - Updated `requirements.txt` to include Ruff and remove replaced tools. - Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
@@ -2,12 +2,11 @@ import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import engine_from_config, pool, text, create_engine
|
||||
from alembic import context
|
||||
from sqlalchemy import create_engine, engine_from_config, pool, text
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from alembic import context
|
||||
|
||||
# Get the path to the app directory (parent of 'alembic')
|
||||
app_dir = Path(__file__).resolve().parent.parent
|
||||
# Add the app directory to Python path
|
||||
@@ -66,7 +65,9 @@ def ensure_database_exists(db_url: str) -> None:
|
||||
admin_url = url.set(database="postgres")
|
||||
|
||||
# CREATE DATABASE cannot run inside a transaction
|
||||
admin_engine = create_engine(str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool)
|
||||
admin_engine = create_engine(
|
||||
str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool
|
||||
)
|
||||
try:
|
||||
with admin_engine.connect() as conn:
|
||||
exists = conn.execute(
|
||||
@@ -122,9 +123,7 @@ def run_migrations_online() -> None:
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
@@ -133,4 +132,4 @@ def run_migrations_online() -> None:
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
run_migrations_online()
|
||||
|
||||
@@ -5,17 +5,17 @@ Revises: fbf6318a8a36
|
||||
Create Date: 2025-11-01 04:15:25.367010
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1174fffbe3e4'
|
||||
down_revision: Union[str, None] = 'fbf6318a8a36'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "1174fffbe3e4"
|
||||
down_revision: str | None = "fbf6318a8a36"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
@@ -24,46 +24,46 @@ def upgrade() -> None:
|
||||
# Index for session cleanup queries
|
||||
# Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff
|
||||
op.create_index(
|
||||
'ix_user_sessions_cleanup',
|
||||
'user_sessions',
|
||||
['is_active', 'expires_at', 'created_at'],
|
||||
"ix_user_sessions_cleanup",
|
||||
"user_sessions",
|
||||
["is_active", "expires_at", "created_at"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('is_active = false')
|
||||
postgresql_where=sa.text("is_active = false"),
|
||||
)
|
||||
|
||||
# Index for user search queries (basic trigram support without pg_trgm extension)
|
||||
# Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%'
|
||||
# Note: For better performance, consider enabling pg_trgm extension
|
||||
op.create_index(
|
||||
'ix_users_email_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(email)')],
|
||||
"ix_users_email_lower",
|
||||
"users",
|
||||
[sa.text("LOWER(email)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_users_first_name_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(first_name)')],
|
||||
"ix_users_first_name_lower",
|
||||
"users",
|
||||
[sa.text("LOWER(first_name)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_users_last_name_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(last_name)')],
|
||||
"ix_users_last_name_lower",
|
||||
"users",
|
||||
[sa.text("LOWER(last_name)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# Index for organization search
|
||||
op.create_index(
|
||||
'ix_organizations_name_lower',
|
||||
'organizations',
|
||||
[sa.text('LOWER(name)')],
|
||||
unique=False
|
||||
"ix_organizations_name_lower",
|
||||
"organizations",
|
||||
[sa.text("LOWER(name)")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -71,8 +71,8 @@ def downgrade() -> None:
|
||||
"""Remove performance indexes."""
|
||||
|
||||
# Drop indexes in reverse order
|
||||
op.drop_index('ix_organizations_name_lower', table_name='organizations')
|
||||
op.drop_index('ix_users_last_name_lower', table_name='users')
|
||||
op.drop_index('ix_users_first_name_lower', table_name='users')
|
||||
op.drop_index('ix_users_email_lower', table_name='users')
|
||||
op.drop_index('ix_user_sessions_cleanup', table_name='user_sessions')
|
||||
op.drop_index("ix_organizations_name_lower", table_name="organizations")
|
||||
op.drop_index("ix_users_last_name_lower", table_name="users")
|
||||
op.drop_index("ix_users_first_name_lower", table_name="users")
|
||||
op.drop_index("ix_users_email_lower", table_name="users")
|
||||
op.drop_index("ix_user_sessions_cleanup", table_name="user_sessions")
|
||||
|
||||
@@ -5,30 +5,32 @@ Revises: 9e4f2a1b8c7d
|
||||
Create Date: 2025-10-30 16:40:21.000021
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2d0fcec3b06d'
|
||||
down_revision: Union[str, None] = '9e4f2a1b8c7d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "2d0fcec3b06d"
|
||||
down_revision: str | None = "9e4f2a1b8c7d"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add deleted_at column for soft deletes
|
||||
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
|
||||
op.add_column(
|
||||
"users", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
|
||||
# Add index on deleted_at for efficient queries
|
||||
op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])
|
||||
op.create_index("ix_users_deleted_at", "users", ["deleted_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove index
|
||||
op.drop_index('ix_users_deleted_at', table_name='users')
|
||||
op.drop_index("ix_users_deleted_at", table_name="users")
|
||||
|
||||
# Remove column
|
||||
op.drop_column('users', 'deleted_at')
|
||||
op.drop_column("users", "deleted_at")
|
||||
|
||||
@@ -5,42 +5,42 @@ Revises: 7396957cbe80
|
||||
Create Date: 2025-02-28 09:19:33.212278
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38bf9e7e74b3'
|
||||
down_revision: Union[str, None] = '7396957cbe80'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "38bf9e7e74b3"
|
||||
down_revision: str | None = "7396957cbe80"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
op.create_table('users',
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('password_hash', sa.String(), nullable=False),
|
||||
sa.Column('first_name', sa.String(), nullable=False),
|
||||
sa.Column('last_name', sa.String(), nullable=True),
|
||||
sa.Column('phone_number', sa.String(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('preferences', sa.JSON(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("email", sa.String(), nullable=False),
|
||||
sa.Column("password_hash", sa.String(), nullable=False),
|
||||
sa.Column("first_name", sa.String(), nullable=False),
|
||||
sa.Column("last_name", sa.String(), nullable=True),
|
||||
sa.Column("phone_number", sa.String(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column("preferences", sa.JSON(), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
||||
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||
op.drop_table("users")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -5,98 +5,85 @@ Revises: b76c725fc3cf
|
||||
Create Date: 2025-10-31 07:41:18.729544
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '549b50ea888d'
|
||||
down_revision: Union[str, None] = 'b76c725fc3cf'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "549b50ea888d"
|
||||
down_revision: str | None = "b76c725fc3cf"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create user_sessions table for per-device session management
|
||||
op.create_table(
|
||||
'user_sessions',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
|
||||
sa.Column('device_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('device_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=500), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('location_city', sa.String(length=100), nullable=True),
|
||||
sa.Column('location_country', sa.String(length=100), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
"user_sessions",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
||||
sa.Column("device_name", sa.String(length=255), nullable=True),
|
||||
sa.Column("device_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("location_city", sa.String(length=100), nullable=True),
|
||||
sa.Column("location_country", sa.String(length=100), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create foreign key to users table
|
||||
op.create_foreign_key(
|
||||
'fk_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
'users',
|
||||
['user_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
"fk_user_sessions_user_id",
|
||||
"user_sessions",
|
||||
"users",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Create indexes for performance
|
||||
# 1. Lookup session by refresh token JTI (most common query)
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti',
|
||||
'user_sessions',
|
||||
['refresh_token_jti'],
|
||||
unique=True
|
||||
"ix_user_sessions_jti", "user_sessions", ["refresh_token_jti"], unique=True
|
||||
)
|
||||
|
||||
# 2. Lookup sessions by user ID
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
['user_id']
|
||||
)
|
||||
op.create_index("ix_user_sessions_user_id", "user_sessions", ["user_id"])
|
||||
|
||||
# 3. Composite index for active sessions by user
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_active',
|
||||
'user_sessions',
|
||||
['user_id', 'is_active']
|
||||
"ix_user_sessions_user_active", "user_sessions", ["user_id", "is_active"]
|
||||
)
|
||||
|
||||
# 4. Index on expires_at for cleanup job
|
||||
op.create_index(
|
||||
'ix_user_sessions_expires_at',
|
||||
'user_sessions',
|
||||
['expires_at']
|
||||
)
|
||||
op.create_index("ix_user_sessions_expires_at", "user_sessions", ["expires_at"])
|
||||
|
||||
# 5. Composite index for active session lookup by JTI
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti_active',
|
||||
'user_sessions',
|
||||
['refresh_token_jti', 'is_active']
|
||||
"ix_user_sessions_jti_active",
|
||||
"user_sessions",
|
||||
["refresh_token_jti", "is_active"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes first
|
||||
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_expires_at', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_id', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_jti', table_name='user_sessions')
|
||||
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_expires_at", table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_user_id", table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_jti", table_name="user_sessions")
|
||||
|
||||
# Drop foreign key
|
||||
op.drop_constraint('fk_user_sessions_user_id', 'user_sessions', type_='foreignkey')
|
||||
op.drop_constraint("fk_user_sessions_user_id", "user_sessions", type_="foreignkey")
|
||||
|
||||
# Drop table
|
||||
op.drop_table('user_sessions')
|
||||
op.drop_table("user_sessions")
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
"""Initial empty migration
|
||||
|
||||
Revision ID: 7396957cbe80
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2025-02-27 12:47:46.445313
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7396957cbe80'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "7396957cbe80"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -5,80 +5,112 @@ Revises: 38bf9e7e74b3
|
||||
Create Date: 2025-10-30 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9e4f2a1b8c7d'
|
||||
down_revision: Union[str, None] = '38bf9e7e74b3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "9e4f2a1b8c7d"
|
||||
down_revision: str | None = "38bf9e7e74b3"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add missing indexes for is_active and is_superuser
|
||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
||||
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||
)
|
||||
|
||||
# Fix column types to match model definitions with explicit lengths
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"email",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"password_hash",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=False,
|
||||
server_default='user') # Add server default
|
||||
op.alter_column(
|
||||
"users",
|
||||
"first_name",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=False,
|
||||
server_default="user",
|
||||
) # Add server default
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=True)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"last_name",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=20),
|
||||
nullable=True)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"phone_number",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=20),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert column types
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(length=20),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"phone_number",
|
||||
existing_type=sa.String(length=20),
|
||||
type_=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"last_name",
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
server_default=None) # Remove server default
|
||||
op.alter_column(
|
||||
"users",
|
||||
"first_name",
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
server_default=None,
|
||||
) # Remove server default
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"password_hash",
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"email",
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Drop indexes
|
||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
||||
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||
|
||||
@@ -5,17 +5,17 @@ Revises: 2d0fcec3b06d
|
||||
Create Date: 2025-10-30 16:41:33.273135
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b76c725fc3cf'
|
||||
down_revision: Union[str, None] = '2d0fcec3b06d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "b76c725fc3cf"
|
||||
down_revision: str | None = "2d0fcec3b06d"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
@@ -23,30 +23,26 @@ def upgrade() -> None:
|
||||
|
||||
# Composite index for filtering active users by role
|
||||
op.create_index(
|
||||
'ix_users_active_superuser',
|
||||
'users',
|
||||
['is_active', 'is_superuser'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
"ix_users_active_superuser",
|
||||
"users",
|
||||
["is_active", "is_superuser"],
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# Composite index for sorting active users by creation date
|
||||
op.create_index(
|
||||
'ix_users_active_created',
|
||||
'users',
|
||||
['is_active', 'created_at'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
"ix_users_active_created",
|
||||
"users",
|
||||
["is_active", "created_at"],
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# Composite index for email lookup of non-deleted users
|
||||
op.create_index(
|
||||
'ix_users_email_not_deleted',
|
||||
'users',
|
||||
['email', 'deleted_at']
|
||||
)
|
||||
op.create_index("ix_users_email_not_deleted", "users", ["email", "deleted_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove composite indexes
|
||||
op.drop_index('ix_users_email_not_deleted', table_name='users')
|
||||
op.drop_index('ix_users_active_created', table_name='users')
|
||||
op.drop_index('ix_users_active_superuser', table_name='users')
|
||||
op.drop_index("ix_users_email_not_deleted", table_name="users")
|
||||
op.drop_index("ix_users_active_created", table_name="users")
|
||||
op.drop_index("ix_users_active_superuser", table_name="users")
|
||||
|
||||
@@ -5,102 +5,123 @@ Revises: 549b50ea888d
|
||||
Create Date: 2025-10-31 12:08:05.141353
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fbf6318a8a36'
|
||||
down_revision: Union[str, None] = '549b50ea888d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "fbf6318a8a36"
|
||||
down_revision: str | None = "549b50ea888d"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create organizations table
|
||||
op.create_table(
|
||||
'organizations',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('slug', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('settings', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
"organizations",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("slug", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("settings", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create indexes for organizations
|
||||
op.create_index('ix_organizations_name', 'organizations', ['name'])
|
||||
op.create_index('ix_organizations_slug', 'organizations', ['slug'], unique=True)
|
||||
op.create_index('ix_organizations_is_active', 'organizations', ['is_active'])
|
||||
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'])
|
||||
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'])
|
||||
op.create_index("ix_organizations_name", "organizations", ["name"])
|
||||
op.create_index("ix_organizations_slug", "organizations", ["slug"], unique=True)
|
||||
op.create_index("ix_organizations_is_active", "organizations", ["is_active"])
|
||||
op.create_index(
|
||||
"ix_organizations_name_active", "organizations", ["name", "is_active"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_organizations_slug_active", "organizations", ["slug", "is_active"]
|
||||
)
|
||||
|
||||
# Create user_organizations junction table
|
||||
op.create_table(
|
||||
'user_organizations',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('organization_id', sa.UUID(), nullable=False),
|
||||
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False, server_default='MEMBER'),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('user_id', 'organization_id')
|
||||
"user_organizations",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("organization_id", sa.UUID(), nullable=False),
|
||||
sa.Column(
|
||||
"role",
|
||||
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
||||
nullable=False,
|
||||
server_default="MEMBER",
|
||||
),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
||||
)
|
||||
|
||||
# Create foreign keys
|
||||
op.create_foreign_key(
|
||||
'fk_user_organizations_user_id',
|
||||
'user_organizations',
|
||||
'users',
|
||||
['user_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
"fk_user_organizations_user_id",
|
||||
"user_organizations",
|
||||
"users",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'fk_user_organizations_organization_id',
|
||||
'user_organizations',
|
||||
'organizations',
|
||||
['organization_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
"fk_user_organizations_organization_id",
|
||||
"user_organizations",
|
||||
"organizations",
|
||||
["organization_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Create indexes for user_organizations
|
||||
op.create_index('ix_user_organizations_role', 'user_organizations', ['role'])
|
||||
op.create_index('ix_user_organizations_is_active', 'user_organizations', ['is_active'])
|
||||
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'])
|
||||
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'])
|
||||
op.create_index("ix_user_organizations_role", "user_organizations", ["role"])
|
||||
op.create_index(
|
||||
"ix_user_organizations_is_active", "user_organizations", ["is_active"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_org_user_active", "user_organizations", ["user_id", "is_active"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_org_org_active", "user_organizations", ["organization_id", "is_active"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes for user_organizations
|
||||
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_organizations_is_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_organizations_role', table_name='user_organizations')
|
||||
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_organizations_is_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_organizations_role", table_name="user_organizations")
|
||||
|
||||
# Drop foreign keys
|
||||
op.drop_constraint('fk_user_organizations_organization_id', 'user_organizations', type_='foreignkey')
|
||||
op.drop_constraint('fk_user_organizations_user_id', 'user_organizations', type_='foreignkey')
|
||||
op.drop_constraint(
|
||||
"fk_user_organizations_organization_id",
|
||||
"user_organizations",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"fk_user_organizations_user_id", "user_organizations", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop user_organizations table
|
||||
op.drop_table('user_organizations')
|
||||
op.drop_table("user_organizations")
|
||||
|
||||
# Drop indexes for organizations
|
||||
op.drop_index('ix_organizations_slug_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_name_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_is_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_slug', table_name='organizations')
|
||||
op.drop_index('ix_organizations_name', table_name='organizations')
|
||||
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
||||
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
||||
op.drop_index("ix_organizations_is_active", table_name="organizations")
|
||||
op.drop_index("ix_organizations_slug", table_name="organizations")
|
||||
op.drop_index("ix_organizations_name", table_name="organizations")
|
||||
|
||||
# Drop organizations table
|
||||
op.drop_table('organizations')
|
||||
op.drop_table("organizations")
|
||||
|
||||
# Drop enum type
|
||||
op.execute('DROP TYPE IF EXISTS organizationrole')
|
||||
op.execute("DROP TYPE IF EXISTS organizationrole")
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status, Header
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
@@ -15,8 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)
|
||||
db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
"""
|
||||
Get the current authenticated user.
|
||||
@@ -36,21 +33,17 @@ async def get_current_user(
|
||||
token_data = get_token_data(token)
|
||||
|
||||
# Get user from database
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == token_data.user_id)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
@@ -59,19 +52,17 @@ async def get_current_user(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token expired",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except TokenInvalidError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Check if the current user is active.
|
||||
|
||||
@@ -86,15 +77,12 @@ def get_current_active_user(
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def get_current_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Check if the current user is a superuser.
|
||||
|
||||
@@ -109,13 +97,12 @@ def get_current_superuser(
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_optional_token(authorization: str = Header(None)) -> Optional[str]:
|
||||
async def get_optional_token(authorization: str = Header(None)) -> str | None:
|
||||
"""
|
||||
Get the token from the Authorization header without requiring it.
|
||||
|
||||
@@ -139,9 +126,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
|
||||
|
||||
|
||||
async def get_optional_current_user(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
token: Optional[str] = Depends(get_optional_token)
|
||||
) -> Optional[User]:
|
||||
db: AsyncSession = Depends(get_db), token: str | None = Depends(get_optional_token)
|
||||
) -> User | None:
|
||||
"""
|
||||
Get the current user if authenticated, otherwise return None.
|
||||
Useful for endpoints that work with both authenticated and unauthenticated users.
|
||||
@@ -158,12 +144,10 @@ async def get_optional_current_user(
|
||||
|
||||
try:
|
||||
token_data = get_token_data(token)
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == token_data.user_id)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
except (TokenExpiredError, TokenInvalidError):
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -7,7 +7,7 @@ These dependencies are optional and flexible:
|
||||
- Use require_org_role for organization-specific access control
|
||||
- Projects can choose to use these or implement their own permission system
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
@@ -20,9 +20,7 @@ from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
|
||||
|
||||
def require_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Dependency to ensure the current user is a superuser.
|
||||
|
||||
@@ -36,7 +34,7 @@ def require_superuser(
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Superuser privileges required"
|
||||
detail="Superuser privileges required",
|
||||
)
|
||||
return current_user
|
||||
|
||||
@@ -62,7 +60,7 @@ class OrganizationPermission:
|
||||
self,
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
Check if user has required role in the organization.
|
||||
@@ -84,21 +82,19 @@ class OrganizationPermission:
|
||||
|
||||
# Get user's role in organization
|
||||
user_role = await organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=organization_id
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
if not user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not a member of this organization"
|
||||
detail="Not a member of this organization",
|
||||
)
|
||||
|
||||
if user_role not in self.allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}"
|
||||
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}",
|
||||
)
|
||||
|
||||
return current_user
|
||||
@@ -106,18 +102,18 @@ class OrganizationPermission:
|
||||
|
||||
# Common permission presets for convenience
|
||||
require_org_owner = OrganizationPermission([OrganizationRole.OWNER])
|
||||
require_org_admin = OrganizationPermission([OrganizationRole.OWNER, OrganizationRole.ADMIN])
|
||||
require_org_member = OrganizationPermission([
|
||||
OrganizationRole.OWNER,
|
||||
OrganizationRole.ADMIN,
|
||||
OrganizationRole.MEMBER
|
||||
])
|
||||
require_org_admin = OrganizationPermission(
|
||||
[OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
)
|
||||
require_org_member = OrganizationPermission(
|
||||
[OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MEMBER]
|
||||
)
|
||||
|
||||
|
||||
async def require_org_membership(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
Ensure user is a member of the organization (any role).
|
||||
@@ -128,15 +124,13 @@ async def require_org_membership(
|
||||
return current_user
|
||||
|
||||
user_role = await organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=organization_id
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
if not user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not a member of this organization"
|
||||
detail="Not a member of this organization",
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth, users, sessions, admin, organizations
|
||||
from app.api.routes import admin, auth, organizations, sessions, users
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
||||
api_router.include_router(organizations.router, prefix="/organizations", tags=["Organizations"])
|
||||
api_router.include_router(
|
||||
organizations.router, prefix="/organizations", tags=["Organizations"]
|
||||
)
|
||||
|
||||
@@ -5,9 +5,10 @@ Admin-specific endpoints for managing users and organizations.
|
||||
These endpoints require superuser privileges and provide CMS-like functionality
|
||||
for managing the application.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
@@ -16,27 +17,32 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode
|
||||
from app.core.exceptions import (
|
||||
AuthorizationError,
|
||||
DuplicateError,
|
||||
ErrorCode,
|
||||
NotFoundError,
|
||||
)
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.crud.session import session as session_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
SortParams,
|
||||
create_pagination_meta
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.organizations import (
|
||||
OrganizationResponse,
|
||||
OrganizationCreate,
|
||||
OrganizationMemberResponse,
|
||||
OrganizationResponse,
|
||||
OrganizationUpdate,
|
||||
OrganizationMemberResponse
|
||||
)
|
||||
from app.schemas.users import UserResponse, UserCreate, UserUpdate
|
||||
from app.schemas.sessions import AdminSessionResponse
|
||||
from app.schemas.users import UserCreate, UserResponse, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,6 +52,7 @@ router = APIRouter()
|
||||
# Schemas for bulk operations
|
||||
class BulkAction(str, Enum):
|
||||
"""Supported bulk actions."""
|
||||
|
||||
ACTIVATE = "activate"
|
||||
DEACTIVATE = "deactivate"
|
||||
DELETE = "delete"
|
||||
@@ -53,36 +60,41 @@ class BulkAction(str, Enum):
|
||||
|
||||
class BulkUserAction(BaseModel):
|
||||
"""Schema for bulk user actions."""
|
||||
|
||||
action: BulkAction = Field(..., description="Action to perform on selected users")
|
||||
user_ids: List[UUID] = Field(..., min_items=1, max_items=100, description="List of user IDs (max 100)")
|
||||
user_ids: list[UUID] = Field(
|
||||
..., min_items=1, max_items=100, description="List of user IDs (max 100)"
|
||||
)
|
||||
|
||||
|
||||
class BulkActionResult(BaseModel):
|
||||
"""Result of a bulk action."""
|
||||
|
||||
success: bool
|
||||
affected_count: int
|
||||
failed_count: int
|
||||
message: str
|
||||
failed_ids: Optional[List[UUID]] = []
|
||||
failed_ids: list[UUID] | None = []
|
||||
|
||||
|
||||
# ===== User Management Endpoints =====
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users",
|
||||
response_model=PaginatedResponse[UserResponse],
|
||||
summary="Admin: List All Users",
|
||||
description="Get paginated list of all users with filtering and search (admin only)",
|
||||
operation_id="admin_list_users"
|
||||
operation_id="admin_list_users",
|
||||
)
|
||||
async def admin_list_users(
|
||||
pagination: PaginationParams = Depends(),
|
||||
sort: SortParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||
search: Optional[str] = Query(None, description="Search by email, name"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
|
||||
search: str | None = Query(None, description="Search by email, name"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with comprehensive filtering and search.
|
||||
@@ -105,20 +117,20 @@ async def admin_list_users(
|
||||
sort_by=sort.sort_by or "created_at",
|
||||
sort_order=sort.sort_order.value if sort.sort_order else "desc",
|
||||
filters=filters if filters else None,
|
||||
search=search
|
||||
search=search,
|
||||
)
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(users)
|
||||
items_count=len(users),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error listing users (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -128,12 +140,12 @@ async def admin_list_users(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Admin: Create User",
|
||||
description="Create a new user (admin only)",
|
||||
operation_id="admin_create_user"
|
||||
operation_id="admin_create_user",
|
||||
)
|
||||
async def admin_create_user(
|
||||
user_in: UserCreate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Create a new user with admin privileges.
|
||||
@@ -145,13 +157,10 @@ async def admin_create_user(
|
||||
logger.info(f"Admin {admin.email} created user {user.email}")
|
||||
return user
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create user: {str(e)}")
|
||||
raise NotFoundError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.USER_ALREADY_EXISTS
|
||||
)
|
||||
logger.warning(f"Failed to create user: {e!s}")
|
||||
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -160,19 +169,18 @@ async def admin_create_user(
|
||||
response_model=UserResponse,
|
||||
summary="Admin: Get User Details",
|
||||
description="Get detailed user information (admin only)",
|
||||
operation_id="admin_get_user"
|
||||
operation_id="admin_get_user",
|
||||
)
|
||||
async def admin_get_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Get detailed information about a specific user."""
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
return user
|
||||
|
||||
@@ -182,21 +190,20 @@ async def admin_get_user(
|
||||
response_model=UserResponse,
|
||||
summary="Admin: Update User",
|
||||
description="Update user information (admin only)",
|
||||
operation_id="admin_update_user"
|
||||
operation_id="admin_update_user",
|
||||
)
|
||||
async def admin_update_user(
|
||||
user_id: UUID,
|
||||
user_in: UserUpdate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Update user information with admin privileges."""
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
|
||||
@@ -206,7 +213,7 @@ async def admin_update_user(
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -215,20 +222,19 @@ async def admin_update_user(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Delete User",
|
||||
description="Soft delete a user (admin only)",
|
||||
operation_id="admin_delete_user"
|
||||
operation_id="admin_delete_user",
|
||||
)
|
||||
async def admin_delete_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Soft delete a user (sets deleted_at timestamp)."""
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
# Prevent deleting yourself
|
||||
@@ -236,21 +242,20 @@ async def admin_delete_user(
|
||||
# Use AuthorizationError for permission/operation restrictions
|
||||
raise AuthorizationError(
|
||||
message="Cannot delete your own account",
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||
)
|
||||
|
||||
await user_crud.soft_delete(db, id=user_id)
|
||||
logger.info(f"Admin {admin.email} deleted user {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} has been deleted"
|
||||
success=True, message=f"User {user.email} has been deleted"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -259,34 +264,32 @@ async def admin_delete_user(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Activate User",
|
||||
description="Activate a user account (admin only)",
|
||||
operation_id="admin_activate_user"
|
||||
operation_id="admin_activate_user",
|
||||
)
|
||||
async def admin_activate_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Activate a user account."""
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
||||
logger.info(f"Admin {admin.email} activated user {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} has been activated"
|
||||
success=True, message=f"User {user.email} has been activated"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error activating user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -295,20 +298,19 @@ async def admin_activate_user(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Deactivate User",
|
||||
description="Deactivate a user account (admin only)",
|
||||
operation_id="admin_deactivate_user"
|
||||
operation_id="admin_deactivate_user",
|
||||
)
|
||||
async def admin_deactivate_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Deactivate a user account."""
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
# Prevent deactivating yourself
|
||||
@@ -316,21 +318,20 @@ async def admin_deactivate_user(
|
||||
# Use AuthorizationError for permission/operation restrictions
|
||||
raise AuthorizationError(
|
||||
message="Cannot deactivate your own account",
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||
)
|
||||
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
||||
logger.info(f"Admin {admin.email} deactivated user {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} has been deactivated"
|
||||
success=True, message=f"User {user.email} has been deactivated"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deactivating user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -339,12 +340,12 @@ async def admin_deactivate_user(
|
||||
response_model=BulkActionResult,
|
||||
summary="Admin: Bulk User Action",
|
||||
description="Perform bulk actions on multiple users (admin only)",
|
||||
operation_id="admin_bulk_user_action"
|
||||
operation_id="admin_bulk_user_action",
|
||||
)
|
||||
async def admin_bulk_user_action(
|
||||
bulk_action: BulkUserAction,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Perform bulk actions on multiple users using optimized bulk operations.
|
||||
@@ -356,22 +357,16 @@ async def admin_bulk_user_action(
|
||||
# Use efficient bulk operations instead of loop
|
||||
if bulk_action.action == BulkAction.ACTIVATE:
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
is_active=True
|
||||
db, user_ids=bulk_action.user_ids, is_active=True
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
is_active=False
|
||||
db, user_ids=bulk_action.user_ids, is_active=False
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DELETE:
|
||||
# bulk_soft_delete automatically excludes the admin user
|
||||
affected_count = await user_crud.bulk_soft_delete(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
exclude_user_id=admin.id
|
||||
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
|
||||
@@ -390,29 +385,30 @@ async def admin_bulk_user_action(
|
||||
affected_count=affected_count,
|
||||
failed_count=failed_count,
|
||||
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped",
|
||||
failed_ids=None # Bulk operations don't track individual failures
|
||||
failed_ids=None, # Bulk operations don't track individual failures
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in bulk user action: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error in bulk user action: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# ===== Organization Management Endpoints =====
|
||||
|
||||
|
||||
@router.get(
|
||||
"/organizations",
|
||||
response_model=PaginatedResponse[OrganizationResponse],
|
||||
summary="Admin: List Organizations",
|
||||
description="Get paginated list of all organizations (admin only)",
|
||||
operation_id="admin_list_organizations"
|
||||
operation_id="admin_list_organizations",
|
||||
)
|
||||
async def admin_list_organizations(
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
search: Optional[str] = Query(None, description="Search by name, slug, description"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
search: str | None = Query(None, description="Search by name, slug, description"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""List all organizations with filtering and search."""
|
||||
try:
|
||||
@@ -422,14 +418,14 @@ async def admin_list_organizations(
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active,
|
||||
search=search
|
||||
search=search,
|
||||
)
|
||||
|
||||
# Build response objects from optimized query results
|
||||
orgs_with_count = []
|
||||
for item in orgs_with_data:
|
||||
org = item['organization']
|
||||
member_count = item['member_count']
|
||||
org = item["organization"]
|
||||
member_count = item["member_count"]
|
||||
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
@@ -440,7 +436,7 @@ async def admin_list_organizations(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": member_count
|
||||
"member_count": member_count,
|
||||
}
|
||||
orgs_with_count.append(OrganizationResponse(**org_dict))
|
||||
|
||||
@@ -448,13 +444,13 @@ async def admin_list_organizations(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(orgs_with_count)
|
||||
items_count=len(orgs_with_count),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing organizations (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -464,12 +460,12 @@ async def admin_list_organizations(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Admin: Create Organization",
|
||||
description="Create a new organization (admin only)",
|
||||
operation_id="admin_create_organization"
|
||||
operation_id="admin_create_organization",
|
||||
)
|
||||
async def admin_create_organization(
|
||||
org_in: OrganizationCreate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Create a new organization."""
|
||||
try:
|
||||
@@ -486,18 +482,15 @@ async def admin_create_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": 0
|
||||
"member_count": 0,
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create organization: {str(e)}")
|
||||
raise NotFoundError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.ALREADY_EXISTS
|
||||
)
|
||||
logger.warning(f"Failed to create organization: {e!s}")
|
||||
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -506,19 +499,18 @@ async def admin_create_organization(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Admin: Get Organization Details",
|
||||
description="Get detailed organization information (admin only)",
|
||||
operation_id="admin_get_organization"
|
||||
operation_id="admin_get_organization",
|
||||
)
|
||||
async def admin_get_organization(
|
||||
org_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Get detailed information about a specific organization."""
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
@@ -530,7 +522,9 @@ async def admin_get_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
@@ -540,13 +534,13 @@ async def admin_get_organization(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Admin: Update Organization",
|
||||
description="Update organization information (admin only)",
|
||||
operation_id="admin_update_organization"
|
||||
operation_id="admin_update_organization",
|
||||
)
|
||||
async def admin_update_organization(
|
||||
org_id: UUID,
|
||||
org_in: OrganizationUpdate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Update organization information."""
|
||||
try:
|
||||
@@ -554,7 +548,7 @@ async def admin_update_organization(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||
@@ -569,14 +563,16 @@ async def admin_update_organization(
|
||||
"settings": updated_org.settings,
|
||||
"created_at": updated_org.created_at,
|
||||
"updated_at": updated_org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=updated_org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -585,12 +581,12 @@ async def admin_update_organization(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Delete Organization",
|
||||
description="Delete an organization (admin only)",
|
||||
operation_id="admin_delete_organization"
|
||||
operation_id="admin_delete_organization",
|
||||
)
|
||||
async def admin_delete_organization(
|
||||
org_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Delete an organization and all its relationships."""
|
||||
try:
|
||||
@@ -598,21 +594,20 @@ async def admin_delete_organization(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
await organization_crud.remove(db, id=org_id)
|
||||
logger.info(f"Admin {admin.email} deleted organization {org.name}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Organization {org.name} has been deleted"
|
||||
success=True, message=f"Organization {org.name} has been deleted"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -621,14 +616,14 @@ async def admin_delete_organization(
|
||||
response_model=PaginatedResponse[OrganizationMemberResponse],
|
||||
summary="Admin: List Organization Members",
|
||||
description="Get all members of an organization (admin only)",
|
||||
operation_id="admin_list_organization_members"
|
||||
operation_id="admin_list_organization_members",
|
||||
)
|
||||
async def admin_list_organization_members(
|
||||
org_id: UUID,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: Optional[bool] = Query(True, description="Filter by active status"),
|
||||
is_active: bool | None = Query(True, description="Filter by active status"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""List all members of an organization."""
|
||||
try:
|
||||
@@ -636,7 +631,7 @@ async def admin_list_organization_members(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
members, total = await organization_crud.get_organization_members(
|
||||
@@ -644,7 +639,7 @@ async def admin_list_organization_members(
|
||||
organization_id=org_id,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
# Convert to response models
|
||||
@@ -654,7 +649,7 @@ async def admin_list_organization_members(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(member_responses)
|
||||
items_count=len(member_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
|
||||
@@ -662,14 +657,19 @@ async def admin_list_organization_members(
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing organization members (admin): {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error listing organization members (admin): {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
class AddMemberRequest(BaseModel):
|
||||
"""Request to add a member to an organization."""
|
||||
|
||||
user_id: UUID = Field(..., description="User ID to add")
|
||||
role: OrganizationRole = Field(OrganizationRole.MEMBER, description="Role in organization")
|
||||
role: OrganizationRole = Field(
|
||||
OrganizationRole.MEMBER, description="Role in organization"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -677,13 +677,13 @@ class AddMemberRequest(BaseModel):
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Add Member to Organization",
|
||||
description="Add a user to an organization (admin only)",
|
||||
operation_id="admin_add_organization_member"
|
||||
operation_id="admin_add_organization_member",
|
||||
)
|
||||
async def admin_add_organization_member(
|
||||
org_id: UUID,
|
||||
request: AddMemberRequest,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Add a user to an organization."""
|
||||
try:
|
||||
@@ -691,21 +691,18 @@ async def admin_add_organization_member(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
user = await user_crud.get(db, id=request.user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {request.user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
await organization_crud.add_user(
|
||||
db,
|
||||
organization_id=org_id,
|
||||
user_id=request.user_id,
|
||||
role=request.role
|
||||
db, organization_id=org_id, user_id=request.user_id, role=request.role
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -714,22 +711,21 @@ async def admin_add_organization_member(
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} added to organization {org.name}"
|
||||
success=True, message=f"User {user.email} added to organization {org.name}"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to add user to organization: {str(e)}")
|
||||
logger.warning(f"Failed to add user to organization: {e!s}")
|
||||
# Use DuplicateError for "already exists" scenarios
|
||||
raise DuplicateError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.USER_ALREADY_EXISTS,
|
||||
field="user_id"
|
||||
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
|
||||
)
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding member to organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error adding member to organization (admin): {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -738,13 +734,13 @@ async def admin_add_organization_member(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Remove Member from Organization",
|
||||
description="Remove a user from an organization (admin only)",
|
||||
operation_id="admin_remove_organization_member"
|
||||
operation_id="admin_remove_organization_member",
|
||||
)
|
||||
async def admin_remove_organization_member(
|
||||
org_id: UUID,
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Remove a user from an organization."""
|
||||
try:
|
||||
@@ -752,39 +748,40 @@ async def admin_remove_organization_member(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
success = await organization_crud.remove_user(
|
||||
db,
|
||||
organization_id=org_id,
|
||||
user_id=user_id
|
||||
db, organization_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise NotFoundError(
|
||||
message="User is not a member of this organization",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
logger.info(f"Admin {admin.email} removed user {user.email} from organization {org.name}")
|
||||
logger.info(
|
||||
f"Admin {admin.email} removed user {user.email} from organization {org.name}"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} removed from organization {org.name}"
|
||||
message=f"User {user.email} removed from organization {org.name}",
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing member from organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error removing member from organization (admin): {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -792,6 +789,7 @@ async def admin_remove_organization_member(
|
||||
# Session Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
response_model=PaginatedResponse[AdminSessionResponse],
|
||||
@@ -802,13 +800,13 @@ async def admin_remove_organization_member(
|
||||
Returns paginated list of sessions with user information.
|
||||
Useful for admin dashboard statistics and session monitoring.
|
||||
""",
|
||||
operation_id="admin_list_sessions"
|
||||
operation_id="admin_list_sessions",
|
||||
)
|
||||
async def admin_list_sessions(
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""List all sessions across all users with filtering and pagination."""
|
||||
try:
|
||||
@@ -818,7 +816,7 @@ async def admin_list_sessions(
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
active_only=is_active if is_active is not None else True,
|
||||
with_user=True
|
||||
with_user=True,
|
||||
)
|
||||
|
||||
# Build response objects with user information
|
||||
@@ -847,21 +845,23 @@ async def admin_list_sessions(
|
||||
last_used_at=session.last_used_at,
|
||||
created_at=session.created_at,
|
||||
expires_at=session.expires_at,
|
||||
is_active=session.is_active
|
||||
is_active=session.is_active,
|
||||
)
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})")
|
||||
logger.info(
|
||||
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})"
|
||||
)
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(session_responses)
|
||||
items_count=len(session_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=session_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -1,39 +1,43 @@
|
||||
# app/api/routes/auth.py
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||
from app.core.auth import get_password_hash
|
||||
from app.core.auth import (
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
decode_token,
|
||||
get_password_hash,
|
||||
)
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
AuthenticationError as AuthError,
|
||||
DatabaseError,
|
||||
ErrorCode
|
||||
ErrorCode,
|
||||
)
|
||||
from app.crud.session import session as session_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionCreate, LogoutRequest
|
||||
from app.schemas.sessions import LogoutRequest, SessionCreate
|
||||
from app.schemas.users import (
|
||||
LoginRequest,
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RefreshTokenRequest,
|
||||
Token,
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
Token,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest,
|
||||
PasswordResetRequest,
|
||||
PasswordResetConfirm
|
||||
)
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
from app.services.email_service import email_service
|
||||
from app.utils.device import extract_device_info
|
||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||
@@ -54,7 +58,7 @@ async def _create_login_session(
|
||||
request: Request,
|
||||
user: User,
|
||||
tokens: Token,
|
||||
login_type: str = "login"
|
||||
login_type: str = "login",
|
||||
) -> None:
|
||||
"""
|
||||
Create a session record for successful login.
|
||||
@@ -81,8 +85,8 @@ async def _create_login_session(
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
@@ -95,15 +99,20 @@ async def _create_login_session(
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
||||
@router.post(
|
||||
"/register",
|
||||
response_model=UserResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
operation_id="register",
|
||||
)
|
||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request, user_data: UserCreate, db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new user.
|
||||
@@ -116,25 +125,23 @@ async def register_user(
|
||||
return user
|
||||
except AuthenticationError as e:
|
||||
# SECURITY: Don't reveal if email exists - generic error message
|
||||
logger.warning(f"Registration failed: {str(e)}")
|
||||
logger.warning(f"Registration failed: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Registration failed. Please check your information and try again."
|
||||
detail="Registration failed. Please check your information and try again.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error during registration: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token, operation_id="login")
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
login_data: LoginRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request, login_data: LoginRequest, db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Login with username and password.
|
||||
@@ -146,14 +153,16 @@ async def login(
|
||||
"""
|
||||
try:
|
||||
# Attempt to authenticate the user
|
||||
user = await AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, login_data.email, login_data.password
|
||||
)
|
||||
|
||||
# Explicitly check for None result and raise correct exception
|
||||
if user is None:
|
||||
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||
raise AuthError(
|
||||
message="Invalid email or password",
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||
)
|
||||
|
||||
# User is authenticated, generate tokens
|
||||
@@ -166,29 +175,26 @@ async def login(
|
||||
|
||||
except AuthenticationError as e:
|
||||
# Handle specific authentication errors like inactive accounts
|
||||
logger.warning(f"Authentication failed: {str(e)}")
|
||||
raise AuthError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
logger.warning(f"Authentication failed: {e!s}")
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error during login: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
|
||||
@router.post("/login/oauth", response_model=Token, operation_id="login_oauth")
|
||||
@limiter.limit("10/minute")
|
||||
async def login_oauth(
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
@@ -199,12 +205,14 @@ async def login_oauth(
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, form_data.username, form_data.password
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise AuthError(
|
||||
message="Invalid email or password",
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
@@ -216,28 +224,25 @@ async def login_oauth(
|
||||
# Return full token response with user data
|
||||
return tokens
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"OAuth authentication failed: {str(e)}")
|
||||
raise AuthError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
logger.warning(f"OAuth authentication failed: {e!s}")
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
|
||||
@limiter.limit("30/minute")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
@@ -249,13 +254,17 @@ async def refresh_token(
|
||||
"""
|
||||
try:
|
||||
# Decode the refresh token to get the JTI
|
||||
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
|
||||
refresh_payload = decode_token(
|
||||
refresh_data.refresh_token, verify_type="refresh"
|
||||
)
|
||||
|
||||
# Check if session exists and is active
|
||||
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if not session:
|
||||
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
|
||||
logger.warning(
|
||||
f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session has been revoked. Please log in again.",
|
||||
@@ -274,10 +283,12 @@ async def refresh_token(
|
||||
db,
|
||||
session=session,
|
||||
new_jti=new_refresh_payload.jti,
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc)
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Failed to update session {session.id}: {session_err!s}", exc_info=True
|
||||
)
|
||||
# Continue anyway - tokens are already issued
|
||||
|
||||
return tokens
|
||||
@@ -300,10 +311,10 @@ async def refresh_token(
|
||||
# Re-raise HTTP exceptions (like session revoked)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||
logger.error(f"Unexpected error during token refresh: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
detail="An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
@@ -320,13 +331,13 @@ async def refresh_token(
|
||||
|
||||
**Rate Limit**: 3 requests/minute
|
||||
""",
|
||||
operation_id="request_password_reset"
|
||||
operation_id="request_password_reset",
|
||||
)
|
||||
@limiter.limit("3/minute")
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
reset_request: PasswordResetRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Request a password reset.
|
||||
@@ -345,26 +356,26 @@ async def request_password_reset(
|
||||
|
||||
# Send password reset email
|
||||
await email_service.send_password_reset_email(
|
||||
to_email=user.email,
|
||||
reset_token=reset_token,
|
||||
user_name=user.first_name
|
||||
to_email=user.email, reset_token=reset_token, user_name=user.first_name
|
||||
)
|
||||
logger.info(f"Password reset requested for {user.email}")
|
||||
else:
|
||||
# Log attempt but don't reveal if email exists
|
||||
logger.warning(f"Password reset requested for non-existent or inactive email: {reset_request.email}")
|
||||
logger.warning(
|
||||
f"Password reset requested for non-existent or inactive email: {reset_request.email}"
|
||||
)
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="If your email is registered, you will receive a password reset link shortly"
|
||||
message="If your email is registered, you will receive a password reset link shortly",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing password reset request: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error processing password reset request: {e!s}", exc_info=True)
|
||||
# Still return success to prevent information leakage
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="If your email is registered, you will receive a password reset link shortly"
|
||||
message="If your email is registered, you will receive a password reset link shortly",
|
||||
)
|
||||
|
||||
|
||||
@@ -378,13 +389,13 @@ async def request_password_reset(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="confirm_password_reset"
|
||||
operation_id="confirm_password_reset",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def confirm_password_reset(
|
||||
request: Request,
|
||||
reset_confirm: PasswordResetConfirm,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Confirm password reset with token.
|
||||
@@ -398,7 +409,7 @@ async def confirm_password_reset(
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired password reset token"
|
||||
detail="Invalid or expired password reset token",
|
||||
)
|
||||
|
||||
# Look up user
|
||||
@@ -406,14 +417,13 @@ async def confirm_password_reset(
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User account is inactive"
|
||||
detail="User account is inactive",
|
||||
)
|
||||
|
||||
# Update password
|
||||
@@ -424,29 +434,33 @@ async def confirm_password_reset(
|
||||
# SECURITY: Invalidate all existing sessions after password reset
|
||||
# This prevents stolen sessions from being used after password change
|
||||
from app.crud.session import session as session_crud
|
||||
|
||||
try:
|
||||
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
||||
db,
|
||||
user_id=str(user.id)
|
||||
db, user_id=str(user.id)
|
||||
)
|
||||
logger.info(
|
||||
f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions"
|
||||
)
|
||||
logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions")
|
||||
except Exception as session_error:
|
||||
# Log but don't fail password reset if session invalidation fails
|
||||
logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}")
|
||||
logger.error(
|
||||
f"Failed to invalidate sessions after password reset: {session_error!s}"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password."
|
||||
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password.",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error confirming password reset: {e!s}", exc_info=True)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while resetting your password"
|
||||
detail="An error occurred while resetting your password",
|
||||
)
|
||||
|
||||
|
||||
@@ -464,14 +478,14 @@ async def confirm_password_reset(
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="logout"
|
||||
operation_id="logout",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def logout(
|
||||
request: Request,
|
||||
logout_request: LogoutRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from current device by deactivating the session.
|
||||
@@ -487,15 +501,14 @@ async def logout(
|
||||
try:
|
||||
# Decode refresh token to get JTI
|
||||
try:
|
||||
refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh")
|
||||
refresh_payload = decode_token(
|
||||
logout_request.refresh_token, verify_type="refresh"
|
||||
)
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
# Even if token is expired/invalid, try to deactivate session
|
||||
logger.warning(f"Logout with invalid/expired token: {str(e)}")
|
||||
logger.warning(f"Logout with invalid/expired token: {e!s}")
|
||||
# Don't fail - return success anyway
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
# Find the session by JTI
|
||||
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||
@@ -509,7 +522,7 @@ async def logout(
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only logout your own sessions"
|
||||
detail="You can only logout your own sessions",
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
@@ -522,22 +535,20 @@ async def logout(
|
||||
else:
|
||||
# Session not found - maybe already deleted or never existed
|
||||
# Return success anyway (idempotent)
|
||||
logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})")
|
||||
logger.info(
|
||||
f"Logout requested for non-existent session (JTI: {refresh_payload.jti})"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
# Don't expose error details
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
logger.error(
|
||||
f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
# Don't expose error details
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -553,13 +564,13 @@ async def logout(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="logout_all"
|
||||
operation_id="logout_all",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def logout_all(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from all devices by deactivating all user sessions.
|
||||
@@ -573,19 +584,25 @@ async def logout_all(
|
||||
"""
|
||||
try:
|
||||
# Deactivate all sessions for this user
|
||||
count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
||||
count = await session_crud.deactivate_all_user_sessions(
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
|
||||
logger.info(
|
||||
f"User {current_user.id} logged out from all devices ({count} sessions)"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)"
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while logging out"
|
||||
detail="An error occurred while logging out",
|
||||
)
|
||||
|
||||
@@ -4,8 +4,9 @@ Organization endpoints for regular users.
|
||||
|
||||
These endpoints allow users to view and manage organizations they belong to.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
@@ -14,18 +15,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, ErrorCode
|
||||
from app.core.exceptions import ErrorCode, NotFoundError
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
create_pagination_meta
|
||||
PaginationParams,
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.organizations import (
|
||||
OrganizationResponse,
|
||||
OrganizationMemberResponse,
|
||||
OrganizationUpdate
|
||||
OrganizationResponse,
|
||||
OrganizationUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,15 +36,15 @@ router = APIRouter()
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=List[OrganizationResponse],
|
||||
response_model=list[OrganizationResponse],
|
||||
summary="Get My Organizations",
|
||||
description="Get all organizations the current user belongs to",
|
||||
operation_id="get_my_organizations"
|
||||
operation_id="get_my_organizations",
|
||||
)
|
||||
async def get_my_organizations(
|
||||
is_active: bool = Query(True, description="Filter by active membership"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get all organizations the current user belongs to.
|
||||
@@ -54,15 +55,13 @@ async def get_my_organizations(
|
||||
try:
|
||||
# Get all org data in single query with JOIN and subquery
|
||||
orgs_data = await organization_crud.get_user_organizations_with_details(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
is_active=is_active
|
||||
db, user_id=current_user.id, is_active=is_active
|
||||
)
|
||||
|
||||
# Transform to response objects
|
||||
orgs_with_data = []
|
||||
for item in orgs_data:
|
||||
org = item['organization']
|
||||
org = item["organization"]
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -72,14 +71,14 @@ async def get_my_organizations(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": item['member_count']
|
||||
"member_count": item["member_count"],
|
||||
}
|
||||
orgs_with_data.append(OrganizationResponse(**org_dict))
|
||||
|
||||
return orgs_with_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting user organizations: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -88,12 +87,12 @@ async def get_my_organizations(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Get Organization Details",
|
||||
description="Get details of an organization the user belongs to",
|
||||
operation_id="get_organization"
|
||||
operation_id="get_organization",
|
||||
)
|
||||
async def get_organization(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get details of a specific organization.
|
||||
@@ -105,7 +104,7 @@ async def get_organization(
|
||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
@@ -117,14 +116,16 @@ async def get_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError: # pragma: no cover - See above
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -133,14 +134,14 @@ async def get_organization(
|
||||
response_model=PaginatedResponse[OrganizationMemberResponse],
|
||||
summary="Get Organization Members",
|
||||
description="Get all members of an organization (members can view)",
|
||||
operation_id="get_organization_members"
|
||||
operation_id="get_organization_members",
|
||||
)
|
||||
async def get_organization_members(
|
||||
organization_id: UUID,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get all members of an organization.
|
||||
@@ -153,7 +154,7 @@ async def get_organization_members(
|
||||
organization_id=organization_id,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
member_responses = [OrganizationMemberResponse(**member) for member in members]
|
||||
@@ -162,13 +163,13 @@ async def get_organization_members(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(member_responses)
|
||||
items_count=len(member_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting organization members: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -177,13 +178,13 @@ async def get_organization_members(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Update Organization",
|
||||
description="Update organization details (admin/owner only)",
|
||||
operation_id="update_organization"
|
||||
operation_id="update_organization",
|
||||
)
|
||||
async def update_organization(
|
||||
organization_id: UUID,
|
||||
org_in: OrganizationUpdate,
|
||||
current_user: User = Depends(require_org_admin),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update organization details.
|
||||
@@ -195,11 +196,13 @@ async def update_organization(
|
||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||
logger.info(f"User {current_user.email} updated organization {updated_org.name}")
|
||||
logger.info(
|
||||
f"User {current_user.email} updated organization {updated_org.name}"
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
"id": updated_org.id,
|
||||
@@ -210,12 +213,14 @@ async def update_organization(
|
||||
"settings": updated_org.settings,
|
||||
"created_at": updated_org.created_at,
|
||||
"updated_at": updated_org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=updated_org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError: # pragma: no cover - See above
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating organization: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error updating organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -3,11 +3,12 @@ Session management endpoints.
|
||||
|
||||
Allows users to view and manage their active sessions across devices.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -15,11 +16,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import decode_token
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||
from app.schemas.sessions import SessionListResponse, SessionResponse
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,13 +40,13 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="list_my_sessions"
|
||||
operation_id="list_my_sessions",
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def list_my_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all active sessions for the current user.
|
||||
@@ -60,18 +61,15 @@ async def list_my_sessions(
|
||||
try:
|
||||
# Get all active sessions for user
|
||||
sessions = await session_crud.get_user_sessions(
|
||||
db,
|
||||
user_id=str(current_user.id),
|
||||
active_only=True
|
||||
db, user_id=str(current_user.id), active_only=True
|
||||
)
|
||||
|
||||
# Try to identify current session from Authorization header
|
||||
current_session_jti = None
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
try:
|
||||
access_token = auth_header.split(" ")[1]
|
||||
token_payload = decode_token(access_token)
|
||||
decode_token(access_token)
|
||||
# Note: Access tokens don't have JTI by default, but we can try
|
||||
# For now, we'll mark current based on most recent activity
|
||||
except Exception:
|
||||
@@ -90,22 +88,27 @@ async def list_my_sessions(
|
||||
last_used_at=s.last_used_at,
|
||||
created_at=s.created_at,
|
||||
expires_at=s.expires_at,
|
||||
is_current=(s == sessions[0] if sessions else False) # Most recent = current
|
||||
is_current=(
|
||||
s == sessions[0] if sessions else False
|
||||
), # Most recent = current
|
||||
)
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
|
||||
logger.info(
|
||||
f"User {current_user.id} listed {len(session_responses)} active sessions"
|
||||
)
|
||||
|
||||
return SessionListResponse(
|
||||
sessions=session_responses,
|
||||
total=len(session_responses)
|
||||
sessions=session_responses, total=len(session_responses)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve sessions"
|
||||
detail="Failed to retrieve sessions",
|
||||
)
|
||||
|
||||
|
||||
@@ -122,14 +125,14 @@ async def list_my_sessions(
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="revoke_session"
|
||||
operation_id="revoke_session",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def revoke_session(
|
||||
request: Request,
|
||||
session_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Revoke a specific session by ID.
|
||||
@@ -149,7 +152,7 @@ async def revoke_session(
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
message=f"Session {session_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify session belongs to current user
|
||||
@@ -160,7 +163,7 @@ async def revoke_session(
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="You can only revoke your own sessions",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
@@ -173,16 +176,16 @@ async def revoke_session(
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}"
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}",
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke session"
|
||||
detail="Failed to revoke session",
|
||||
)
|
||||
|
||||
|
||||
@@ -198,13 +201,13 @@ async def revoke_session(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="cleanup_expired_sessions"
|
||||
operation_id="cleanup_expired_sessions",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def cleanup_expired_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Cleanup expired sessions for the current user.
|
||||
@@ -219,21 +222,24 @@ async def cleanup_expired_sessions(
|
||||
try:
|
||||
# Use optimized bulk DELETE instead of N individual deletes
|
||||
deleted_count = await session_crud.cleanup_expired_for_user(
|
||||
db,
|
||||
user_id=str(current_user.id)
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
||||
logger.info(
|
||||
f"User {current_user.id} cleaned up {deleted_count} expired sessions"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Cleaned up {deleted_count} expired sessions"
|
||||
success=True, message=f"Cleaned up {deleted_count} expired sessions"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error cleaning up sessions for user {current_user.id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cleanup sessions"
|
||||
detail="Failed to cleanup sessions",
|
||||
)
|
||||
|
||||
@@ -1,33 +1,30 @@
|
||||
"""
|
||||
User management endpoints for CRUD operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status, Request
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||
from app.api.dependencies.auth import get_current_superuser, get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
NotFoundError,
|
||||
AuthorizationError,
|
||||
ErrorCode
|
||||
)
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
SortParams,
|
||||
create_pagination_meta
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,15 +47,15 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="list_users"
|
||||
operation_id="list_users",
|
||||
)
|
||||
async def list_users(
|
||||
pagination: PaginationParams = Depends(),
|
||||
sort: SortParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with pagination, filtering, and sorting.
|
||||
@@ -80,7 +77,7 @@ async def list_users(
|
||||
limit=pagination.limit,
|
||||
sort_by=sort.sort_by,
|
||||
sort_order=sort.sort_order.value if sort.sort_order else "asc",
|
||||
filters=filters if filters else None
|
||||
filters=filters if filters else None,
|
||||
)
|
||||
|
||||
# Create pagination metadata
|
||||
@@ -88,15 +85,12 @@ async def list_users(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(users)
|
||||
items_count=len(users),
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=users,
|
||||
pagination=pagination_meta
|
||||
)
|
||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error listing users: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -111,11 +105,9 @@ async def list_users(
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_current_user_profile"
|
||||
operation_id="get_current_user_profile",
|
||||
)
|
||||
def get_current_user_profile(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Any:
|
||||
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
|
||||
"""Get current user's profile."""
|
||||
return current_user
|
||||
|
||||
@@ -133,12 +125,12 @@ def get_current_user_profile(
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_current_user"
|
||||
operation_id="update_current_user",
|
||||
)
|
||||
async def update_current_user(
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update current user's profile.
|
||||
@@ -147,17 +139,17 @@ async def update_current_user(
|
||||
"""
|
||||
try:
|
||||
updated_user = await user_crud.update(
|
||||
db,
|
||||
db_obj=current_user,
|
||||
obj_in=user_update
|
||||
db, db_obj=current_user, obj_in=user_update
|
||||
)
|
||||
logger.info(f"User {current_user.id} updated their profile")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {current_user.id}: {str(e)}")
|
||||
logger.error(f"Error updating user {current_user.id}: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -175,12 +167,12 @@ async def update_current_user(
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_user_by_id"
|
||||
operation_id="get_user_by_id",
|
||||
)
|
||||
async def get_user_by_id(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get user by ID.
|
||||
@@ -194,7 +186,7 @@ async def get_user_by_id(
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to view this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
@@ -202,7 +194,7 @@ async def get_user_by_id(
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
return user
|
||||
@@ -222,13 +214,13 @@ async def get_user_by_id(
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_user"
|
||||
operation_id="update_user",
|
||||
)
|
||||
async def update_user(
|
||||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update user by ID.
|
||||
@@ -245,7 +237,7 @@ async def update_user(
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to update this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
@@ -253,7 +245,7 @@ async def update_user(
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -261,10 +253,10 @@ async def update_user(
|
||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {user_id}: {str(e)}")
|
||||
logger.error(f"Error updating user {user_id}: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -281,14 +273,14 @@ async def update_user(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="change_current_user_password"
|
||||
operation_id="change_current_user_password",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def change_current_user_password(
|
||||
request: Request,
|
||||
password_change: PasswordChange,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
@@ -300,23 +292,23 @@ async def change_current_user_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=password_change.current_password,
|
||||
new_password=password_change.new_password
|
||||
new_password=password_change.new_password,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"User {current_user.id} changed their password")
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password changed successfully"
|
||||
success=True, message="Password changed successfully"
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
|
||||
logger.warning(
|
||||
f"Failed password change attempt for user {current_user.id}: {e!s}"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
|
||||
logger.error(f"Error changing password for user {current_user.id}: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -335,12 +327,12 @@ async def change_current_user_password(
|
||||
|
||||
**Note**: This performs a hard delete. Consider implementing soft deletes for production.
|
||||
""",
|
||||
operation_id="delete_user"
|
||||
operation_id="delete_user",
|
||||
)
|
||||
async def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Delete user by ID (superuser only).
|
||||
@@ -351,7 +343,7 @@ async def delete_user(
|
||||
if str(user_id) == str(current_user.id):
|
||||
raise AuthorizationError(
|
||||
message="Cannot delete your own account",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
@@ -359,7 +351,7 @@ async def delete_user(
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -367,12 +359,11 @@ async def delete_user(
|
||||
await user_crud.soft_delete(db, id=str(user_id))
|
||||
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user_id} deleted successfully"
|
||||
success=True, message=f"User {user_id} deleted successfully"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error deleting user {user_id}: {str(e)}")
|
||||
logger.error(f"Error deleting user {user_id}: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
import logging
|
||||
logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import uuid
|
||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
"""Base authentication error"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenExpiredError(AuthError):
|
||||
"""Token has expired"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenInvalidError(AuthError):
|
||||
"""Token is invalid"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenMissingClaimError(AuthError):
|
||||
"""Token is missing a required claim"""
|
||||
pass
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
@@ -62,8 +62,7 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
partial(pwd_context.verify, plain_password, hashed_password)
|
||||
None, partial(pwd_context.verify, plain_password, hashed_password)
|
||||
)
|
||||
|
||||
|
||||
@@ -82,17 +81,13 @@ async def get_password_hash_async(password: str) -> str:
|
||||
Hashed password string
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
pwd_context.hash,
|
||||
password
|
||||
)
|
||||
return await loop.run_in_executor(None, pwd_context.hash, password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
claims: Optional[Dict[str, Any]] = None
|
||||
subject: str | Any,
|
||||
expires_delta: timedelta | None = None,
|
||||
claims: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
@@ -106,17 +101,19 @@ def create_access_token(
|
||||
Encoded JWT token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(UTC) + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
# Base token data
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(tz=timezone.utc),
|
||||
"iat": datetime.now(tz=UTC),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
@@ -125,17 +122,14 @@ def create_access_token(
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
subject: str | Any, expires_delta: timedelta | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT refresh token.
|
||||
@@ -148,28 +142,26 @@ def create_refresh_token(
|
||||
Encoded JWT refresh token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
expire = datetime.now(UTC) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"iat": datetime.now(UTC),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "refresh"
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
"""
|
||||
Decode and verify a JWT token.
|
||||
|
||||
@@ -195,8 +187,8 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "sub", "iat"]
|
||||
}
|
||||
"require": ["exp", "sub", "iat"],
|
||||
},
|
||||
)
|
||||
|
||||
# SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks
|
||||
@@ -250,4 +242,4 @@ def get_token_data(token: str) -> TokenData:
|
||||
user_id = payload.sub
|
||||
is_superuser = payload.is_superuser or False
|
||||
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
@@ -13,7 +12,7 @@ class Settings(BaseSettings):
|
||||
# Environment (must be before SECRET_KEY for validation)
|
||||
ENVIRONMENT: str = Field(
|
||||
default="development",
|
||||
description="Environment: development, staging, or production"
|
||||
description="Environment: development, staging, or production",
|
||||
)
|
||||
|
||||
# Security: Content Security Policy
|
||||
@@ -21,8 +20,7 @@ class Settings(BaseSettings):
|
||||
# Set to True for strict CSP (blocks most external resources)
|
||||
# Set to "relaxed" for modern frontend development
|
||||
CSP_MODE: str = Field(
|
||||
default="relaxed",
|
||||
description="CSP mode: 'strict', 'relaxed', or 'disabled'"
|
||||
default="relaxed", description="CSP mode: 'strict', 'relaxed', or 'disabled'"
|
||||
)
|
||||
|
||||
# Database configuration
|
||||
@@ -31,7 +29,7 @@ class Settings(BaseSettings):
|
||||
POSTGRES_HOST: str = "localhost"
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "app"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
DATABASE_URL: str | None = None
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
@@ -59,38 +57,36 @@ class Settings(BaseSettings):
|
||||
SECRET_KEY: str = Field(
|
||||
default="dev_only_insecure_key_change_in_production_32chars_min",
|
||||
min_length=32,
|
||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
|
||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'",
|
||||
)
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard)
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days
|
||||
|
||||
# CORS configuration
|
||||
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
||||
BACKEND_CORS_ORIGINS: list[str] = ["http://localhost:3000"]
|
||||
|
||||
# Frontend URL for email links
|
||||
FRONTEND_URL: str = Field(
|
||||
default="http://localhost:3000",
|
||||
description="Frontend application URL for email links"
|
||||
description="Frontend application URL for email links",
|
||||
)
|
||||
|
||||
# Admin user
|
||||
FIRST_SUPERUSER_EMAIL: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Email for first superuser account"
|
||||
FIRST_SUPERUSER_EMAIL: str | None = Field(
|
||||
default=None, description="Email for first superuser account"
|
||||
)
|
||||
FIRST_SUPERUSER_PASSWORD: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Password for first superuser (min 12 characters)"
|
||||
FIRST_SUPERUSER_PASSWORD: str | None = Field(
|
||||
default=None, description="Password for first superuser (min 12 characters)"
|
||||
)
|
||||
|
||||
@field_validator('SECRET_KEY')
|
||||
@field_validator("SECRET_KEY")
|
||||
@classmethod
|
||||
def validate_secret_key(cls, v: str, info) -> str:
|
||||
"""Validate SECRET_KEY is secure, especially in production."""
|
||||
# Get environment from values if available
|
||||
values_data = info.data if info.data else {}
|
||||
env = values_data.get('ENVIRONMENT', 'development')
|
||||
env = values_data.get("ENVIRONMENT", "development")
|
||||
|
||||
if v.startswith("your_secret_key_here"):
|
||||
if env == "production":
|
||||
@@ -106,13 +102,15 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
if len(v) < 32:
|
||||
raise ValueError("SECRET_KEY must be at least 32 characters long for security")
|
||||
raise ValueError(
|
||||
"SECRET_KEY must be at least 32 characters long for security"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('FIRST_SUPERUSER_PASSWORD')
|
||||
@field_validator("FIRST_SUPERUSER_PASSWORD")
|
||||
@classmethod
|
||||
def validate_superuser_password(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_superuser_password(cls, v: str | None) -> str | None:
|
||||
"""Validate superuser password strength."""
|
||||
if v is None:
|
||||
return v
|
||||
@@ -121,7 +119,13 @@ class Settings(BaseSettings):
|
||||
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
|
||||
|
||||
# Check for common weak passwords
|
||||
weak_passwords = {'admin123', 'Admin123', 'password123', 'Password123', '123456789012'}
|
||||
weak_passwords = {
|
||||
"admin123",
|
||||
"Admin123",
|
||||
"password123",
|
||||
"Password123",
|
||||
"123456789012",
|
||||
}
|
||||
if v in weak_passwords:
|
||||
raise ValueError(
|
||||
"FIRST_SUPERUSER_PASSWORD is too weak. "
|
||||
@@ -144,8 +148,8 @@ class Settings(BaseSettings):
|
||||
"env_file": "../.env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": True,
|
||||
"extra": "ignore" # Ignore extra fields from .env (e.g., frontend-specific vars)
|
||||
"extra": "ignore", # Ignore extra fields from .env (e.g., frontend-specific vars)
|
||||
}
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings = Settings()
|
||||
|
||||
@@ -5,17 +5,18 @@ Database configuration using SQLAlchemy 2.0 and asyncpg.
|
||||
This module provides async database connectivity with proper connection pooling
|
||||
and session management for FastAPI endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
@@ -27,12 +28,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
@compiles(JSONB, "sqlite")
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
@compiles(UUID, "sqlite")
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
@@ -40,7 +41,6 @@ def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
# Declarative base for models (SQLAlchemy 2.0 style)
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
pass
|
||||
|
||||
|
||||
def get_async_database_url(url: str) -> str:
|
||||
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
logger.debug("Async transaction committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Async transaction failed, rolling back: {str(e)}")
|
||||
logger.error(f"Async transaction failed, rolling back: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Async database health check failed: {str(e)}")
|
||||
logger.error(f"Async database health check failed: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Custom exceptions and global exception handlers for the API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
@@ -27,17 +27,13 @@ class APIException(HTTPException):
|
||||
status_code: int,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
field: Optional[str] = None,
|
||||
headers: Optional[dict] = None
|
||||
field: str | None = None,
|
||||
headers: dict | None = None,
|
||||
):
|
||||
self.error_code = error_code
|
||||
self.field = field
|
||||
self.message = message
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail=message,
|
||||
headers=headers
|
||||
)
|
||||
super().__init__(status_code=status_code, detail=message, headers=headers)
|
||||
|
||||
|
||||
class AuthenticationError(APIException):
|
||||
@@ -47,14 +43,14 @@ class AuthenticationError(APIException):
|
||||
self,
|
||||
message: str = "Authentication failed",
|
||||
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field,
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
@@ -64,12 +60,12 @@ class AuthorizationError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Insufficient permissions",
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,12 +75,12 @@ class NotFoundError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Resource not found",
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -95,13 +91,13 @@ class DuplicateError(APIException):
|
||||
self,
|
||||
message: str = "Resource already exists",
|
||||
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
field=field,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,13 +108,13 @@ class ValidationException(APIException):
|
||||
self,
|
||||
message: str = "Validation error",
|
||||
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
field=field,
|
||||
)
|
||||
|
||||
|
||||
@@ -128,12 +124,12 @@ class DatabaseError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Database operation failed",
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -152,23 +148,18 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=exc.error_code,
|
||||
message=exc.message,
|
||||
field=exc.field
|
||||
)]
|
||||
errors=[ErrorDetail(code=exc.error_code, message=exc.message, field=exc.field)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: Union[RequestValidationError, ValidationError]
|
||||
request: Request, exc: RequestValidationError | ValidationError
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Handler for Pydantic validation errors.
|
||||
@@ -189,22 +180,19 @@ async def validation_exception_handler(
|
||||
# Skip 'body' or 'query' prefix in location
|
||||
field = ".".join(str(x) for x in error["loc"][1:])
|
||||
|
||||
errors.append(ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message=error["msg"],
|
||||
field=field
|
||||
))
|
||||
errors.append(
|
||||
ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR, message=error["msg"], field=field
|
||||
)
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Validation error: {len(errors)} errors "
|
||||
f"(path: {request.url.path})"
|
||||
)
|
||||
logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})")
|
||||
|
||||
error_response = ErrorResponse(errors=errors)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=error_response.model_dump()
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@@ -226,26 +214,21 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
|
||||
}
|
||||
|
||||
error_code = status_code_to_error_code.get(
|
||||
exc.status_code,
|
||||
ErrorCode.INTERNAL_ERROR
|
||||
exc.status_code, ErrorCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"HTTP exception: {exc.status_code} - {exc.detail} "
|
||||
f"(path: {request.url.path})"
|
||||
f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=error_code,
|
||||
message=str(exc.detail)
|
||||
)]
|
||||
errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
|
||||
@@ -257,26 +240,24 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
|
||||
leaking sensitive information in production.
|
||||
"""
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__} - {str(exc)} "
|
||||
f"Unhandled exception: {type(exc).__name__} - {exc!s} "
|
||||
f"(path: {request.url.path})",
|
||||
exc_info=True
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# In production, don't expose internal error details
|
||||
from app.core.config import settings
|
||||
|
||||
if settings.ENVIRONMENT == "production":
|
||||
message = "An internal error occurred. Please try again later."
|
||||
else:
|
||||
message = f"{type(exc).__name__}: {str(exc)}"
|
||||
message = f"{type(exc).__name__}: {exc!s}"
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=message
|
||||
)]
|
||||
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=error_response.model_dump()
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
@@ -3,4 +3,4 @@ from .organization import organization
|
||||
from .session import session as session_crud
|
||||
from .user import user
|
||||
|
||||
__all__ = ["user", "session_crud", "organization"]
|
||||
__all__ = ["organization", "session_crud", "user"]
|
||||
|
||||
@@ -4,14 +4,16 @@ Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
from datetime import UTC
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
@@ -24,10 +26,14 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
class CRUDBase[
|
||||
ModelType: Base,
|
||||
CreateSchemaType: BaseModel,
|
||||
UpdateSchemaType: BaseModel,
|
||||
]:
|
||||
"""Async CRUD operations for a model."""
|
||||
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
def __init__(self, model: type[ModelType]):
|
||||
"""
|
||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
@@ -37,11 +43,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
self.model = model
|
||||
|
||||
async def get(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: str,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> Optional[ModelType]:
|
||||
self, db: AsyncSession, id: str, options: list[Load] | None = None
|
||||
) -> ModelType | None:
|
||||
"""
|
||||
Get a single record by ID with UUID validation and optional eager loading.
|
||||
|
||||
@@ -66,7 +69,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
|
||||
logger.warning(f"Invalid UUID format: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -80,7 +83,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
@@ -89,8 +92,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> List[ModelType]:
|
||||
options: list[Load] | None = None,
|
||||
) -> list[ModelType]:
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
|
||||
@@ -122,10 +125,14 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
logger.error(
|
||||
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: # pragma: no cover
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: CreateSchemaType
|
||||
) -> ModelType: # pragma: no cover
|
||||
"""Create a new record with error handling.
|
||||
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
@@ -142,19 +149,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.warning(
|
||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
|
||||
raise ValueError(f"Database operation failed: {e!s}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
@@ -162,7 +175,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
obj_in: UpdateSchemaType | dict[str, Any],
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
@@ -182,22 +195,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.warning(
|
||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
|
||||
raise ValueError(f"Database operation failed: {e!s}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
@@ -206,7 +225,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -216,7 +235,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
logger.warning(
|
||||
f"{self.model.__name__} with id {id} not found for deletion"
|
||||
)
|
||||
return None
|
||||
|
||||
await db.delete(obj)
|
||||
@@ -224,12 +245,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
raise ValueError(
|
||||
f"Cannot delete {self.model.__name__}: referenced by other records"
|
||||
)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_total(
|
||||
@@ -238,10 +264,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[List[ModelType], int]:
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> tuple[list[ModelType], int]:
|
||||
"""
|
||||
Get multiple records with total count, filtering, and sorting.
|
||||
|
||||
@@ -269,7 +295,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
query = select(self.model)
|
||||
|
||||
# Exclude soft-deleted records by default
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
query = query.where(self.model.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
@@ -298,7 +324,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
logger.error(
|
||||
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
@@ -307,7 +335,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
|
||||
logger.error(f"Error counting {self.model.__name__} records: {e!s}")
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
@@ -315,13 +343,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj = await self.get(db, id=id)
|
||||
return obj is not None
|
||||
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Soft delete a record by setting deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
@@ -330,7 +358,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -340,26 +368,33 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
||||
logger.warning(
|
||||
f"{self.model.__name__} with id {id} not found for soft deletion"
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if model supports soft deletes
|
||||
if not hasattr(self.model, 'deleted_at'):
|
||||
if not hasattr(self.model, "deleted_at"):
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
raise ValueError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
# Set deleted_at timestamp
|
||||
obj.deleted_at = datetime.now(timezone.utc)
|
||||
obj.deleted_at = datetime.now(UTC)
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||
|
||||
@@ -372,25 +407,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Find the soft-deleted record
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
result = await db.execute(
|
||||
select(self.model).where(
|
||||
self.model.id == uuid_obj,
|
||||
self.model.deleted_at.isnot(None)
|
||||
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
|
||||
)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
else:
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
raise ValueError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
|
||||
logger.warning(
|
||||
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
|
||||
)
|
||||
return None
|
||||
|
||||
# Clear deleted_at timestamp
|
||||
@@ -401,5 +439,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
# app/crud/organization_async.py
|
||||
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, and_, select, case
|
||||
from sqlalchemy import and_, case, func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.schemas.organizations import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
|
||||
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||
"""Async CRUD operations for Organization model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
||||
"""Get organization by slug."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
@@ -31,10 +32,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
|
||||
logger.error(f"Error getting organization by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||
) -> Organization:
|
||||
"""Create a new organization with error handling."""
|
||||
try:
|
||||
db_obj = Organization(
|
||||
@@ -42,7 +45,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
is_active=obj_in.is_active,
|
||||
settings=obj_in.settings or {}
|
||||
settings=obj_in.settings or {},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
@@ -50,15 +53,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
|
||||
raise ValueError(
|
||||
f"Organization with slug '{obj_in.slug}' already exists"
|
||||
)
|
||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error creating organization: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
@@ -67,11 +74,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc"
|
||||
) -> tuple[List[Organization], int]:
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Organization], int]:
|
||||
"""
|
||||
Get multiple organizations with filtering, searching, and sorting.
|
||||
|
||||
@@ -89,7 +96,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
@@ -112,7 +119,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
|
||||
return organizations, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with filters: {str(e)}")
|
||||
logger.error(f"Error getting organizations with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
@@ -122,13 +129,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(func.count(UserOrganization.user_id)).where(
|
||||
and_(
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one() or 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error getting member count for organization {organization_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_member_counts(
|
||||
@@ -137,9 +146,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
||||
This eliminates the N+1 query problem.
|
||||
@@ -156,13 +165,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
func.count(
|
||||
func.distinct(
|
||||
case(
|
||||
(UserOrganization.is_active == True, UserOrganization.user_id),
|
||||
else_=None
|
||||
(
|
||||
UserOrganization.is_active,
|
||||
UserOrganization.user_id,
|
||||
),
|
||||
else_=None,
|
||||
)
|
||||
)
|
||||
).label('member_count')
|
||||
).label("member_count"),
|
||||
)
|
||||
.outerjoin(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.group_by(Organization.id)
|
||||
)
|
||||
|
||||
@@ -174,7 +189,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
@@ -189,24 +204,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
query = (
|
||||
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Convert to list of dicts
|
||||
orgs_with_counts = [
|
||||
{
|
||||
'organization': org,
|
||||
'member_count': member_count
|
||||
}
|
||||
{"organization": org, "member_count": member_count}
|
||||
for org, member_count in rows
|
||||
]
|
||||
|
||||
return orgs_with_counts, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error getting organizations with member counts: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def add_user(
|
||||
@@ -216,7 +232,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
try:
|
||||
@@ -225,7 +241,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -249,7 +265,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id=organization_id,
|
||||
role=role,
|
||||
is_active=True,
|
||||
custom_permissions=custom_permissions
|
||||
custom_permissions=custom_permissions,
|
||||
)
|
||||
db.add(user_org)
|
||||
await db.commit()
|
||||
@@ -257,19 +273,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return user_org
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Integrity error adding user to organization: {str(e)}")
|
||||
logger.error(f"Integrity error adding user to organization: {e!s}")
|
||||
raise ValueError("Failed to add user to organization")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID
|
||||
self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
|
||||
) -> bool:
|
||||
"""Remove a user from an organization (soft delete)."""
|
||||
try:
|
||||
@@ -277,7 +289,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -291,7 +303,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return True
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_user_role(
|
||||
@@ -301,15 +313,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole,
|
||||
custom_permissions: Optional[str] = None
|
||||
) -> Optional[UserOrganization]:
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization | None:
|
||||
"""Update a user's role in an organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -326,7 +338,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return user_org
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error updating user role: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_organization_members(
|
||||
@@ -336,8 +348,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool = True
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
is_active: bool = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get members of an organization with user details.
|
||||
|
||||
@@ -359,46 +371,55 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
count_query = select(func.count()).select_from(
|
||||
select(UserOrganization)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
.where(UserOrganization.is_active == is_active if is_active is not None else True)
|
||||
.where(
|
||||
UserOrganization.is_active == is_active
|
||||
if is_active is not None
|
||||
else True
|
||||
)
|
||||
.alias()
|
||||
)
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply ordering and pagination
|
||||
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
|
||||
query = (
|
||||
query.order_by(UserOrganization.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
results = result.all()
|
||||
|
||||
members = []
|
||||
for user_org, user in results:
|
||||
members.append({
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at
|
||||
})
|
||||
members.append(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at,
|
||||
}
|
||||
)
|
||||
|
||||
return members, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {str(e)}")
|
||||
logger.error(f"Error getting organization members: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Organization]:
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
) -> list[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
try:
|
||||
query = (
|
||||
select(Organization)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
@@ -408,16 +429,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {str(e)}")
|
||||
logger.error(f"Error getting user organizations: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get user's organizations with role and member count in SINGLE QUERY.
|
||||
Eliminates N+1 problem by using subquery for member counts.
|
||||
@@ -430,9 +447,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
member_count_subq = (
|
||||
select(
|
||||
UserOrganization.organization_id,
|
||||
func.count(UserOrganization.user_id).label('member_count')
|
||||
func.count(UserOrganization.user_id).label("member_count"),
|
||||
)
|
||||
.where(UserOrganization.is_active == True)
|
||||
.where(UserOrganization.is_active)
|
||||
.group_by(UserOrganization.organization_id)
|
||||
.subquery()
|
||||
)
|
||||
@@ -442,10 +459,18 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(
|
||||
Organization,
|
||||
UserOrganization.role,
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label(
|
||||
"member_count"
|
||||
),
|
||||
)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.outerjoin(
|
||||
member_count_subq,
|
||||
Organization.id == member_count_subq.c.organization_id,
|
||||
)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
@@ -456,25 +481,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
'organization': org,
|
||||
'role': role,
|
||||
'member_count': member_count
|
||||
}
|
||||
{"organization": org, "role": role, "member_count": member_count}
|
||||
for org, role, member_count in rows
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error getting user organizations with details: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> Optional[OrganizationRole]:
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> OrganizationRole | None:
|
||||
"""Get a user's role in a specific organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
@@ -482,7 +501,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -490,29 +509,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
|
||||
return user_org.role if user_org else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user role in org: {str(e)}")
|
||||
logger.error(f"Error getting user role in org: {e!s}")
|
||||
raise
|
||||
|
||||
async def is_user_org_owner(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner of an organization."""
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role == OrganizationRole.OWNER
|
||||
|
||||
async def is_user_org_admin(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner or admin of an organization."""
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
"""
|
||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select, update, delete, func
|
||||
from sqlalchemy import and_, delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Async CRUD operations for user sessions."""
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""
|
||||
Get session by refresh token JTI.
|
||||
|
||||
@@ -38,10 +38,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
||||
logger.error(f"Error getting session by JTI {jti}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
async def get_active_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> UserSession | None:
|
||||
"""
|
||||
Get active session by refresh token JTI.
|
||||
|
||||
@@ -57,13 +59,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
select(UserSession).where(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active == True
|
||||
UserSession.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
||||
logger.error(f"Error getting active session by JTI {jti}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_sessions(
|
||||
@@ -72,8 +74,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True,
|
||||
with_user: bool = False
|
||||
) -> List[UserSession]:
|
||||
with_user: bool = False,
|
||||
) -> list[UserSession]:
|
||||
"""
|
||||
Get all sessions for a user with optional eager loading.
|
||||
|
||||
@@ -97,20 +99,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active == True)
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
query = query.order_by(UserSession.last_used_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
||||
logger.error(f"Error getting sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: SessionCreate
|
||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""
|
||||
Create a new user session.
|
||||
@@ -151,10 +150,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {str(e)}")
|
||||
logger.error(f"Error creating session: {e!s}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {e!s}")
|
||||
|
||||
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
|
||||
async def deactivate(
|
||||
self, db: AsyncSession, *, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""
|
||||
Deactivate a session (logout from device).
|
||||
|
||||
@@ -184,14 +185,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
||||
logger.error(f"Error deactivating session {session_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
self, db: AsyncSession, *, user_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Deactivate all active sessions for a user (logout from all devices).
|
||||
@@ -209,12 +207,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
|
||||
stmt = (
|
||||
update(UserSession)
|
||||
.where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
.where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
|
||||
.values(is_active=False)
|
||||
)
|
||||
|
||||
@@ -228,14 +221,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def update_last_used(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession
|
||||
self, db: AsyncSession, *, session: UserSession
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update the last_used_at timestamp for a session.
|
||||
@@ -248,14 +238,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
||||
logger.error(f"Error updating last_used for session {session.id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def update_refresh_token(
|
||||
@@ -264,7 +254,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime
|
||||
new_expires_at: datetime,
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update session with new refresh token JTI and expiration.
|
||||
@@ -283,14 +273,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error updating refresh token for session {session.id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
@@ -311,15 +303,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||
now = datetime.now(timezone.utc)
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False,
|
||||
not UserSession.is_active,
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date
|
||||
UserSession.created_at < cutoff_date,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -334,15 +326,10 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
||||
logger.error(f"Error cleaning up expired sessions: {e!s}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired_for_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""
|
||||
Clean up expired and inactive sessions for a specific user.
|
||||
|
||||
@@ -363,14 +350,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
logger.error(f"Invalid UUID format: {user_id}")
|
||||
raise ValueError(f"Invalid user ID format: {user_id}")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.user_id == uuid_obj,
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < now
|
||||
not UserSession.is_active,
|
||||
UserSession.expires_at < now,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -388,7 +375,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
|
||||
f"Error cleaning up expired sessions for user {user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -409,15 +396,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
and_(UserSession.user_id == user_uuid, UserSession.is_active)
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
||||
logger.error(f"Error counting sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_all_sessions(
|
||||
@@ -427,8 +411,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
with_user: bool = True
|
||||
) -> tuple[List[UserSession], int]:
|
||||
with_user: bool = True,
|
||||
) -> tuple[list[UserSession], int]:
|
||||
"""
|
||||
Get all sessions across all users with pagination (admin only).
|
||||
|
||||
@@ -451,18 +435,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active == True)
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count(UserSession.id))
|
||||
if active_only:
|
||||
count_query = count_query.where(UserSession.is_active == True)
|
||||
count_query = count_query.where(UserSession.is_active)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(UserSession.last_used_at.desc()).offset(skip).limit(limit)
|
||||
query = (
|
||||
query.order_by(UserSession.last_used_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
sessions = list(result.scalars().all())
|
||||
@@ -470,7 +458,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return sessions, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all sessions: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# app/crud/user_async.py
|
||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_, select, update
|
||||
@@ -20,15 +21,13 @@ logger = logging.getLogger(__name__)
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
"""Async CRUD operations for User model."""
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
||||
"""Get user by email address."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == email)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by email {email}: {str(e)}")
|
||||
logger.error(f"Error getting user by email {email}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
@@ -42,9 +41,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
password_hash=password_hash,
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
phone_number=obj_in.phone_number
|
||||
if hasattr(obj_in, "phone_number")
|
||||
else None,
|
||||
is_superuser=obj_in.is_superuser
|
||||
if hasattr(obj_in, "is_superuser")
|
||||
else False,
|
||||
preferences={},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
@@ -52,7 +55,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
||||
@@ -60,15 +63,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
|
||||
) -> User:
|
||||
"""Update user with async password hashing if password is updated."""
|
||||
if isinstance(obj_in, dict):
|
||||
@@ -79,7 +78,9 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
# Handle password separately if it exists in update data
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
|
||||
update_data["password_hash"] = await get_password_hash_async(
|
||||
update_data["password"]
|
||||
)
|
||||
del update_data["password"]
|
||||
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
@@ -90,11 +91,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
search: Optional[str] = None
|
||||
) -> Tuple[List[User], int]:
|
||||
filters: dict[str, Any] | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[User], int]:
|
||||
"""
|
||||
Get multiple users with total count, filtering, sorting, and search.
|
||||
|
||||
@@ -136,12 +137,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
search_filter = or_(
|
||||
User.email.ilike(f"%{search}%"),
|
||||
User.first_name.ilike(f"%{search}%"),
|
||||
User.last_name.ilike(f"%{search}%")
|
||||
User.last_name.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
from sqlalchemy import func
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
@@ -162,15 +164,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
return users, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
||||
logger.error(f"Error retrieving paginated users: {e!s}")
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
is_active: bool
|
||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||
) -> int:
|
||||
"""
|
||||
Bulk update is_active status for multiple users.
|
||||
@@ -192,7 +190,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
||||
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
|
||||
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
@@ -204,15 +202,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
exclude_user_id: Optional[UUID] = None
|
||||
user_ids: list[UUID],
|
||||
exclude_user_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Bulk soft delete multiple users.
|
||||
@@ -239,11 +237,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
|
||||
.where(
|
||||
User.deleted_at.is_(None)
|
||||
) # Don't re-delete already deleted users
|
||||
.values(
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
deleted_at=datetime.now(UTC),
|
||||
is_active=False,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -256,7 +256,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
|
||||
@@ -4,9 +4,9 @@ Async database initialization script.
|
||||
|
||||
Creates the first superuser if configured and doesn't already exist.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal, engine
|
||||
@@ -17,7 +17,7 @@ from app.schemas.users import UserCreate
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def init_db() -> Optional[User]:
|
||||
async def init_db() -> User | None:
|
||||
"""
|
||||
Initialize database with first superuser if settings are configured and user doesn't exist.
|
||||
|
||||
@@ -49,7 +49,7 @@ async def init_db() -> Optional[User]:
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
@@ -70,13 +70,13 @@ async def main():
|
||||
# Configure logging to show info logs
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
try:
|
||||
user = await init_db()
|
||||
if user:
|
||||
print(f"✓ Database initialized successfully")
|
||||
print("✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
|
||||
@@ -2,10 +2,10 @@ import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from fastapi import FastAPI, status, Request, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
@@ -19,9 +19,9 @@ from app.core.database import check_database_health
|
||||
from app.core.exceptions import (
|
||||
APIException,
|
||||
api_exception_handler,
|
||||
validation_exception_handler,
|
||||
http_exception_handler,
|
||||
unhandled_exception_handler
|
||||
unhandled_exception_handler,
|
||||
validation_exception_handler,
|
||||
)
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
@@ -52,11 +52,11 @@ async def lifespan(app: FastAPI):
|
||||
# Runs daily at 2:00 AM server time
|
||||
scheduler.add_job(
|
||||
cleanup_expired_sessions,
|
||||
'cron',
|
||||
"cron",
|
||||
hour=2,
|
||||
minute=0,
|
||||
id='cleanup_expired_sessions',
|
||||
replace_existing=True
|
||||
id="cleanup_expired_sessions",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
@@ -73,12 +73,12 @@ async def lifespan(app: FastAPI):
|
||||
logger.info("Scheduled jobs stopped")
|
||||
|
||||
|
||||
logger.info(f"Starting app!!!")
|
||||
logger.info("Starting app!!!")
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
version=settings.VERSION,
|
||||
openapi_url=f"{settings.API_V1_STR}/openapi.json",
|
||||
lifespan=lifespan
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add rate limiter state to app
|
||||
@@ -96,7 +96,14 @@ app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only
|
||||
allow_methods=[
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"OPTIONS",
|
||||
], # Explicit methods only
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
@@ -129,12 +136,14 @@ async def limit_request_size(request: Request, call_next):
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
content={
|
||||
"success": False,
|
||||
"errors": [{
|
||||
"code": "REQUEST_TOO_LARGE",
|
||||
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
|
||||
"field": None
|
||||
}]
|
||||
}
|
||||
"errors": [
|
||||
{
|
||||
"code": "REQUEST_TOO_LARGE",
|
||||
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
|
||||
"field": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
@@ -165,15 +174,19 @@ async def add_security_headers(request: Request, call_next):
|
||||
|
||||
# Enforce HTTPS in production
|
||||
if settings.ENVIRONMENT == "production":
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Strict-Transport-Security"] = (
|
||||
"max-age=31536000; includeSubDomains"
|
||||
)
|
||||
|
||||
# Content Security Policy
|
||||
csp_mode = settings.CSP_MODE.lower()
|
||||
|
||||
# Special handling for API docs
|
||||
is_docs = request.url.path in ["/docs", "/redoc"] or \
|
||||
request.url.path.startswith("/docs/") or \
|
||||
request.url.path.startswith("/redoc/")
|
||||
is_docs = (
|
||||
request.url.path in ["/docs", "/redoc"]
|
||||
or request.url.path.startswith("/docs/")
|
||||
or request.url.path.startswith("/redoc/")
|
||||
)
|
||||
|
||||
if csp_mode == "disabled":
|
||||
# No CSP (only for local development/debugging)
|
||||
@@ -264,7 +277,7 @@ async def root():
|
||||
description="Check the health status of the API and its dependencies",
|
||||
response_description="Health status information",
|
||||
tags=["Health"],
|
||||
operation_id="health_check"
|
||||
operation_id="health_check",
|
||||
)
|
||||
async def health_check() -> JSONResponse:
|
||||
"""
|
||||
@@ -278,12 +291,12 @@ async def health_check() -> JSONResponse:
|
||||
- environment: Current environment (development, staging, production)
|
||||
- database: Database connectivity status
|
||||
"""
|
||||
health_status: Dict[str, Any] = {
|
||||
health_status: dict[str, Any] = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"version": settings.VERSION,
|
||||
"environment": settings.ENVIRONMENT,
|
||||
"checks": {}
|
||||
"checks": {},
|
||||
}
|
||||
|
||||
response_status = status.HTTP_200_OK
|
||||
@@ -294,7 +307,7 @@ async def health_check() -> JSONResponse:
|
||||
if db_healthy:
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "Database connection successful"
|
||||
"message": "Database connection successful",
|
||||
}
|
||||
else:
|
||||
raise Exception("Database health check returned unhealthy status")
|
||||
@@ -302,15 +315,12 @@ async def health_check() -> JSONResponse:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Database connection failed: {str(e)}"
|
||||
"message": f"Database connection failed: {e!s}",
|
||||
}
|
||||
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
logger.error(f"Health check failed - database error: {e}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response_status,
|
||||
content=health_status
|
||||
)
|
||||
return JSONResponse(status_code=response_status, content=health_status)
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
@@ -2,17 +2,25 @@
|
||||
Models package initialization.
|
||||
Imports all models to ensure they're registered with SQLAlchemy.
|
||||
"""
|
||||
|
||||
# First import Base to avoid circular imports
|
||||
from app.core.database import Base
|
||||
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
from .organization import Organization
|
||||
|
||||
# Import models
|
||||
from .user import User
|
||||
from .user_organization import UserOrganization, OrganizationRole
|
||||
from .user_organization import OrganizationRole, UserOrganization
|
||||
from .user_session import UserSession
|
||||
|
||||
__all__ = [
|
||||
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||
'User', 'UserSession',
|
||||
'Organization', 'UserOrganization', 'OrganizationRole',
|
||||
]
|
||||
"Base",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
"User",
|
||||
"UserOrganization",
|
||||
"UserSession",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin to add created_at and updated_at timestamps to models"""
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin to add UUID primary keys to models"""
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# app/models/organization.py
|
||||
from sqlalchemy import Column, String, Boolean, Text, Index
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -11,7 +11,8 @@ class Organization(Base, UUIDMixin, TimestampMixin):
|
||||
Organization model for multi-tenant support.
|
||||
Users can belong to multiple organizations with different roles.
|
||||
"""
|
||||
__tablename__ = 'organizations'
|
||||
|
||||
__tablename__ = "organizations"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
@@ -20,11 +21,13 @@ class Organization(Base, UUIDMixin, TimestampMixin):
|
||||
settings = Column(JSONB, default={})
|
||||
|
||||
# Relationships
|
||||
user_organizations = relationship("UserOrganization", back_populates="organization", cascade="all, delete-orphan")
|
||||
user_organizations = relationship(
|
||||
"UserOrganization", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_organizations_name_active', 'name', 'is_active'),
|
||||
Index('ix_organizations_slug_active', 'slug', 'is_active'),
|
||||
Index("ix_organizations_name_active", "name", "is_active"),
|
||||
Index("ix_organizations_slug_active", "slug", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Boolean, DateTime
|
||||
from sqlalchemy import Boolean, Column, DateTime, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -6,7 +6,7 @@ from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class User(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = 'users'
|
||||
__tablename__ = "users"
|
||||
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
@@ -19,7 +19,9 @@ class User(Base, UUIDMixin, TimestampMixin):
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
user_organizations = relationship("UserOrganization", back_populates="user", cascade="all, delete-orphan")
|
||||
user_organizations = relationship(
|
||||
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
return f"<User {self.email}>"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# app/models/user_organization.py
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Boolean, String, Index, Enum
|
||||
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -14,6 +14,7 @@ class OrganizationRole(str, PyEnum):
|
||||
These provide a baseline role system that can be optionally used.
|
||||
Projects can extend this or implement their own permission system.
|
||||
"""
|
||||
|
||||
OWNER = "owner" # Full control over organization
|
||||
ADMIN = "admin" # Can manage users and settings
|
||||
MEMBER = "member" # Regular member with standard access
|
||||
@@ -25,25 +26,41 @@ class UserOrganization(Base, TimestampMixin):
|
||||
Junction table for many-to-many relationship between Users and Organizations.
|
||||
Includes role information for flexible RBAC.
|
||||
"""
|
||||
__tablename__ = 'user_organizations'
|
||||
|
||||
user_id = Column(PGUUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), primary_key=True)
|
||||
organization_id = Column(PGUUID(as_uuid=True), ForeignKey('organizations.id', ondelete='CASCADE'), primary_key=True)
|
||||
__tablename__ = "user_organizations"
|
||||
|
||||
role = Column(Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False, index=True)
|
||||
user_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
organization_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("organizations.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
role = Column(
|
||||
Enum(OrganizationRole),
|
||||
default=OrganizationRole.MEMBER,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Optional: Custom permissions override for specific users
|
||||
custom_permissions = Column(String(500), nullable=True) # JSON array of permission strings
|
||||
custom_permissions = Column(
|
||||
String(500), nullable=True
|
||||
) # JSON array of permission strings
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="user_organizations")
|
||||
organization = relationship("Organization", back_populates="user_organizations")
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_user_org_user_active', 'user_id', 'is_active'),
|
||||
Index('ix_user_org_org_active', 'organization_id', 'is_active'),
|
||||
Index('ix_user_org_role', 'role'),
|
||||
Index("ix_user_org_user_active", "user_id", "is_active"),
|
||||
Index("ix_user_org_org_active", "organization_id", "is_active"),
|
||||
Index("ix_user_org_role", "role"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -6,7 +6,10 @@ This allows users to:
|
||||
- Logout from specific devices
|
||||
- Manage their active sessions
|
||||
"""
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Index
|
||||
|
||||
from datetime import UTC
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -20,19 +23,27 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
Each time a user logs in from a device, a new session is created.
|
||||
Sessions are identified by the refresh token JTI (JWT ID).
|
||||
"""
|
||||
__tablename__ = 'user_sessions'
|
||||
|
||||
__tablename__ = "user_sessions"
|
||||
|
||||
# Foreign key to user
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Refresh token identifier (JWT ID from the refresh token)
|
||||
refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# Device information
|
||||
device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook"
|
||||
device_id = Column(String(255), nullable=True) # Persistent device identifier (from client)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
|
||||
user_agent = Column(String(500), nullable=True) # Browser/app user agent
|
||||
device_id = Column(
|
||||
String(255), nullable=True
|
||||
) # Persistent device identifier (from client)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
|
||||
user_agent = Column(String(500), nullable=True) # Browser/app user agent
|
||||
|
||||
# Session timing
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=False)
|
||||
@@ -50,8 +61,8 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Composite indexes for performance (defined in migration)
|
||||
__table_args__ = (
|
||||
Index('ix_user_sessions_user_active', 'user_id', 'is_active'),
|
||||
Index('ix_user_sessions_jti_active', 'refresh_token_jti', 'is_active'),
|
||||
Index("ix_user_sessions_user_active", "user_id", "is_active"),
|
||||
Index("ix_user_sessions_jti_active", "refresh_token_jti", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -60,21 +71,24 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if session has expired."""
|
||||
from datetime import datetime, timezone
|
||||
return self.expires_at < datetime.now(timezone.utc)
|
||||
from datetime import datetime
|
||||
|
||||
return self.expires_at < datetime.now(UTC)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert session to dictionary for serialization."""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'user_id': str(self.user_id),
|
||||
'device_name': self.device_name,
|
||||
'device_id': self.device_id,
|
||||
'ip_address': self.ip_address,
|
||||
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
'expires_at': self.expires_at.isoformat() if self.expires_at else None,
|
||||
'is_active': self.is_active,
|
||||
'location_city': self.location_city,
|
||||
'location_country': self.location_country,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
"id": str(self.id),
|
||||
"user_id": str(self.user_id),
|
||||
"device_name": self.device_name,
|
||||
"device_id": self.device_id,
|
||||
"ip_address": self.ip_address,
|
||||
"last_used_at": self.last_used_at.isoformat()
|
||||
if self.last_used_at
|
||||
else None,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"is_active": self.is_active,
|
||||
"location_city": self.location_city,
|
||||
"location_country": self.location_country,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
"""
|
||||
Common schemas used across the API for pagination, responses, filtering, and sorting.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
from typing import Generic, TypeVar, List, Optional
|
||||
from typing import TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order options."""
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
@@ -20,16 +22,9 @@ class SortOrder(str, Enum):
|
||||
class PaginationParams(BaseModel):
|
||||
"""Parameters for pagination."""
|
||||
|
||||
page: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
description="Page number (1-indexed)"
|
||||
)
|
||||
page: int = Field(default=1, ge=1, description="Page number (1-indexed)")
|
||||
limit: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of items per page (max 100)"
|
||||
default=20, ge=1, le=100, description="Number of items per page (max 100)"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -42,34 +37,20 @@ class PaginationParams(BaseModel):
|
||||
"""Alias for offset (compatibility with existing code)."""
|
||||
return self.offset
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"page": 1,
|
||||
"limit": 20
|
||||
}
|
||||
}
|
||||
}
|
||||
model_config = {"json_schema_extra": {"example": {"page": 1, "limit": 20}}}
|
||||
|
||||
|
||||
class SortParams(BaseModel):
|
||||
"""Parameters for sorting."""
|
||||
|
||||
sort_by: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Field name to sort by"
|
||||
)
|
||||
sort_by: str | None = Field(default=None, description="Field name to sort by")
|
||||
sort_order: SortOrder = Field(
|
||||
default=SortOrder.ASC,
|
||||
description="Sort order (asc or desc)"
|
||||
default=SortOrder.ASC, description="Sort order (asc or desc)"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"sort_by": "created_at",
|
||||
"sort_order": "desc"
|
||||
}
|
||||
"example": {"sort_by": "created_at", "sort_order": "desc"}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,32 +73,30 @@ class PaginationMeta(BaseModel):
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
"has_prev": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
class PaginatedResponse[T](BaseModel):
|
||||
"""Generic paginated response wrapper."""
|
||||
|
||||
data: List[T] = Field(..., description="List of items")
|
||||
data: list[T] = Field(..., description="List of items")
|
||||
pagination: PaginationMeta = Field(..., description="Pagination metadata")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"data": [
|
||||
{"id": "123", "name": "Example Item"}
|
||||
],
|
||||
"data": [{"id": "123", "name": "Example Item"}],
|
||||
"pagination": {
|
||||
"total": 150,
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
}
|
||||
"has_prev": False,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -131,10 +110,7 @@ class MessageResponse(BaseModel):
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Operation completed successfully"
|
||||
}
|
||||
"example": {"success": True, "message": "Operation completed successfully"}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,11 +118,11 @@ class MessageResponse(BaseModel):
|
||||
class BulkActionRequest(BaseModel):
|
||||
"""Request schema for bulk operations on multiple items."""
|
||||
|
||||
ids: List[UUID] = Field(
|
||||
ids: list[UUID] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="List of item IDs to perform action on (max 100)"
|
||||
description="List of item IDs to perform action on (max 100)",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
@@ -154,7 +130,7 @@ class BulkActionRequest(BaseModel):
|
||||
"example": {
|
||||
"ids": [
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"6ba7b810-9dad-11d1-80b4-00c04fd430c8"
|
||||
"6ba7b810-9dad-11d1-80b4-00c04fd430c8",
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -166,24 +142,23 @@ class BulkActionResponse(BaseModel):
|
||||
|
||||
success: bool = Field(default=True, description="Operation success status")
|
||||
message: str = Field(..., description="Human-readable message")
|
||||
affected_count: int = Field(..., description="Number of items affected by the operation")
|
||||
affected_count: int = Field(
|
||||
..., description="Number of items affected by the operation"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Successfully deactivated 5 users",
|
||||
"affected_count": 5
|
||||
"affected_count": 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_pagination_meta(
|
||||
total: int,
|
||||
page: int,
|
||||
limit: int,
|
||||
items_count: int
|
||||
total: int, page: int, limit: int, items_count: int
|
||||
) -> PaginationMeta:
|
||||
"""
|
||||
Helper function to create pagination metadata.
|
||||
@@ -205,5 +180,5 @@ def create_pagination_meta(
|
||||
page_size=items_count,
|
||||
total_pages=total_pages,
|
||||
has_next=page < total_pages,
|
||||
has_prev=page > 1
|
||||
has_prev=page > 1,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Error schemas for standardized API error responses.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -53,14 +53,14 @@ class ErrorDetail(BaseModel):
|
||||
|
||||
code: ErrorCode = Field(..., description="Machine-readable error code")
|
||||
message: str = Field(..., description="Human-readable error message")
|
||||
field: Optional[str] = Field(None, description="Field name if error is field-specific")
|
||||
field: str | None = Field(None, description="Field name if error is field-specific")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"code": "VAL_002",
|
||||
"message": "Password must be at least 8 characters long",
|
||||
"field": "password"
|
||||
"field": "password",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,7 @@ class ErrorResponse(BaseModel):
|
||||
"""Standardized error response format."""
|
||||
|
||||
success: bool = Field(default=False, description="Always false for error responses")
|
||||
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
errors: list[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
@@ -80,9 +80,9 @@ class ErrorResponse(BaseModel):
|
||||
{
|
||||
"code": "AUTH_001",
|
||||
"message": "Invalid email or password",
|
||||
"field": None
|
||||
"field": None,
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# app/schemas/organizations.py
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, field_validator, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from app.models.user_organization import OrganizationRole
|
||||
|
||||
@@ -12,85 +12,94 @@ from app.models.user_organization import OrganizationRole
|
||||
# Organization Schemas
|
||||
class OrganizationBase(BaseModel):
|
||||
"""Base organization schema with common fields."""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = None
|
||||
is_active: bool = True
|
||||
settings: Optional[Dict[str, Any]] = {}
|
||||
|
||||
@field_validator('slug')
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
is_active: bool = True
|
||||
settings: dict[str, Any] | None = {}
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r'^[a-z0-9-]+$', v):
|
||||
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
|
||||
if v.startswith('-') or v.endswith('-'):
|
||||
raise ValueError('Slug cannot start or end with a hyphen')
|
||||
if '--' in v:
|
||||
raise ValueError('Slug cannot contain consecutive hyphens')
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate organization name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError('Organization name cannot be empty')
|
||||
raise ValueError("Organization name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class OrganizationCreate(OrganizationBase):
|
||||
"""Schema for creating a new organization."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
|
||||
|
||||
class OrganizationUpdate(BaseModel):
|
||||
"""Schema for updating an organization."""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
slug: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
settings: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator('slug')
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
is_active: bool | None = None
|
||||
settings: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r'^[a-z0-9-]+$', v):
|
||||
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
|
||||
if v.startswith('-') or v.endswith('-'):
|
||||
raise ValueError('Slug cannot start or end with a hyphen')
|
||||
if '--' in v:
|
||||
raise ValueError('Slug cannot contain consecutive hyphens')
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate organization name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError('Organization name cannot be empty')
|
||||
raise ValueError("Organization name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class OrganizationResponse(OrganizationBase):
|
||||
"""Schema for organization API responses."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
member_count: Optional[int] = 0
|
||||
updated_at: datetime | None = None
|
||||
member_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationListResponse(BaseModel):
|
||||
"""Schema for paginated organization list responses."""
|
||||
organizations: List[OrganizationResponse]
|
||||
|
||||
organizations: list[OrganizationResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
@@ -100,44 +109,49 @@ class OrganizationListResponse(BaseModel):
|
||||
# User-Organization Relationship Schemas
|
||||
class UserOrganizationBase(BaseModel):
|
||||
"""Base schema for user-organization relationship."""
|
||||
|
||||
role: OrganizationRole = OrganizationRole.MEMBER
|
||||
is_active: bool = True
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationCreate(BaseModel):
|
||||
"""Schema for adding a user to an organization."""
|
||||
|
||||
user_id: UUID
|
||||
role: OrganizationRole = OrganizationRole.MEMBER
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationUpdate(BaseModel):
|
||||
"""Schema for updating user's role in an organization."""
|
||||
role: Optional[OrganizationRole] = None
|
||||
is_active: Optional[bool] = None
|
||||
custom_permissions: Optional[str] = None
|
||||
|
||||
role: OrganizationRole | None = None
|
||||
is_active: bool | None = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationResponse(BaseModel):
|
||||
"""Schema for user-organization relationship responses."""
|
||||
|
||||
user_id: UUID
|
||||
organization_id: UUID
|
||||
role: OrganizationRole
|
||||
is_active: bool
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationMemberResponse(BaseModel):
|
||||
"""Schema for organization member information."""
|
||||
|
||||
user_id: UUID
|
||||
email: str
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
last_name: str | None = None
|
||||
role: OrganizationRole
|
||||
is_active: bool
|
||||
joined_at: datetime
|
||||
@@ -147,7 +161,8 @@ class OrganizationMemberResponse(BaseModel):
|
||||
|
||||
class OrganizationMemberListResponse(BaseModel):
|
||||
"""Schema for paginated organization member list."""
|
||||
members: List[OrganizationMemberResponse]
|
||||
|
||||
members: list[OrganizationMemberResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
@@ -1,37 +1,44 @@
|
||||
"""
|
||||
Pydantic schemas for user session management.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class SessionBase(BaseModel):
|
||||
"""Base schema for user sessions."""
|
||||
device_name: Optional[str] = Field(None, max_length=255, description="Friendly device name")
|
||||
device_id: Optional[str] = Field(None, max_length=255, description="Persistent device identifier")
|
||||
|
||||
device_name: str | None = Field(
|
||||
None, max_length=255, description="Friendly device name"
|
||||
)
|
||||
device_id: str | None = Field(
|
||||
None, max_length=255, description="Persistent device identifier"
|
||||
)
|
||||
|
||||
|
||||
class SessionCreate(SessionBase):
|
||||
"""Schema for creating a new session (internal use)."""
|
||||
|
||||
user_id: UUID
|
||||
refresh_token_jti: str = Field(..., max_length=255)
|
||||
ip_address: Optional[str] = Field(None, max_length=45)
|
||||
user_agent: Optional[str] = Field(None, max_length=500)
|
||||
ip_address: str | None = Field(None, max_length=45)
|
||||
user_agent: str | None = Field(None, max_length=500)
|
||||
last_used_at: datetime
|
||||
expires_at: datetime
|
||||
location_city: Optional[str] = Field(None, max_length=100)
|
||||
location_country: Optional[str] = Field(None, max_length=100)
|
||||
location_city: str | None = Field(None, max_length=100)
|
||||
location_country: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class SessionUpdate(BaseModel):
|
||||
"""Schema for updating a session (internal use)."""
|
||||
last_used_at: Optional[datetime] = None
|
||||
is_active: Optional[bool] = None
|
||||
refresh_token_jti: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
last_used_at: datetime | None = None
|
||||
is_active: bool | None = None
|
||||
refresh_token_jti: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class SessionResponse(SessionBase):
|
||||
@@ -40,14 +47,17 @@ class SessionResponse(SessionBase):
|
||||
|
||||
This is what users see when they list their active sessions.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
ip_address: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
ip_address: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
last_used_at: datetime
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
is_current: bool = Field(default=False, description="Whether this is the current session")
|
||||
is_current: bool = Field(
|
||||
default=False, description="Whether this is the current session"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
@@ -62,14 +72,15 @@ class SessionResponse(SessionBase):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
"is_current": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
"""Response containing list of sessions."""
|
||||
|
||||
sessions: list[SessionResponse]
|
||||
total: int = Field(..., description="Total number of active sessions")
|
||||
|
||||
@@ -84,10 +95,10 @@ class SessionListResponse(BaseModel):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
"is_current": True,
|
||||
}
|
||||
],
|
||||
"total": 1
|
||||
"total": 1,
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -95,17 +106,14 @@ class SessionListResponse(BaseModel):
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request schema for logout endpoint."""
|
||||
|
||||
refresh_token: str = Field(
|
||||
...,
|
||||
description="Refresh token for the session to logout from",
|
||||
min_length=10
|
||||
..., description="Refresh token for the session to logout from", min_length=10
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
||||
}
|
||||
"example": {"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -116,13 +124,14 @@ class AdminSessionResponse(SessionBase):
|
||||
|
||||
Includes user information for admin to see who owns each session.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
user_id: UUID
|
||||
user_email: str = Field(..., description="Email of the user who owns this session")
|
||||
user_full_name: Optional[str] = Field(None, description="Full name of the user")
|
||||
ip_address: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
user_full_name: str | None = Field(None, description="Full name of the user")
|
||||
ip_address: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
last_used_at: datetime
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
@@ -144,20 +153,21 @@ class AdminSessionResponse(SessionBase):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_active": True
|
||||
"is_active": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class DeviceInfo(BaseModel):
|
||||
"""Device information extracted from request."""
|
||||
device_name: Optional[str] = None
|
||||
device_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
|
||||
device_name: str | None = None
|
||||
device_id: str | None = None
|
||||
ip_address: str | None = None
|
||||
user_agent: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
@@ -167,7 +177,7 @@ class DeviceInfo(BaseModel):
|
||||
"ip_address": "192.168.1.50",
|
||||
"user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...",
|
||||
"location_city": "San Francisco",
|
||||
"location_country": "United States"
|
||||
"location_country": "United States",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# app/schemas/users.py
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
|
||||
from app.schemas.validators import validate_password_strength, validate_phone_number
|
||||
|
||||
@@ -11,12 +11,12 @@ from app.schemas.validators import validate_password_strength, validate_phone_nu
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
last_name: str | None = None
|
||||
phone_number: str | None = None
|
||||
|
||||
@field_validator('phone_number')
|
||||
@field_validator("phone_number")
|
||||
@classmethod
|
||||
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_phone(cls, v: str | None) -> str | None:
|
||||
return validate_phone_number(v)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class UserCreate(UserBase):
|
||||
password: str
|
||||
is_superuser: bool = False
|
||||
|
||||
@field_validator('password')
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -32,30 +32,32 @@ class UserCreate(UserBase):
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates
|
||||
is_superuser: Optional[bool] = None # Explicitly reject privilege escalation attempts
|
||||
first_name: str | None = None
|
||||
last_name: str | None = None
|
||||
phone_number: str | None = None
|
||||
password: str | None = None
|
||||
preferences: dict[str, Any] | None = None
|
||||
is_active: bool | None = (
|
||||
None # Changed default from True to None to avoid unintended updates
|
||||
)
|
||||
is_superuser: bool | None = None # Explicitly reject privilege escalation attempts
|
||||
|
||||
@field_validator('phone_number')
|
||||
@field_validator("phone_number")
|
||||
@classmethod
|
||||
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_phone(cls, v: str | None) -> str | None:
|
||||
return validate_phone_number(v)
|
||||
|
||||
@field_validator('password')
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: Optional[str]) -> Optional[str]:
|
||||
def password_strength(cls, v: str | None) -> str | None:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
if v is None:
|
||||
return v
|
||||
return validate_password_strength(v)
|
||||
|
||||
@field_validator('is_superuser')
|
||||
@field_validator("is_superuser")
|
||||
@classmethod
|
||||
def prevent_superuser_modification(cls, v: Optional[bool]) -> Optional[bool]:
|
||||
def prevent_superuser_modification(cls, v: bool | None) -> bool | None:
|
||||
"""Prevent users from modifying their superuser status via this schema."""
|
||||
if v is not None:
|
||||
raise ValueError("Cannot modify superuser status through user update")
|
||||
@@ -67,7 +69,7 @@ class UserInDB(UserBase):
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -77,28 +79,28 @@ class UserResponse(UserBase):
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
refresh_token: str | None = None
|
||||
token_type: str = "bearer"
|
||||
user: "UserResponse" # Forward reference since UserResponse is defined above
|
||||
expires_in: Optional[int] = None # Token expiration in seconds
|
||||
expires_in: int | None = None # Token expiration in seconds
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str # User ID
|
||||
exp: int # Expiration time
|
||||
iat: Optional[int] = None # Issued at
|
||||
jti: Optional[str] = None # JWT ID
|
||||
is_superuser: Optional[bool] = False
|
||||
first_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
type: Optional[str] = None # Token type (access/refresh)
|
||||
iat: int | None = None # Issued at
|
||||
jti: str | None = None # JWT ID
|
||||
is_superuser: bool | None = False
|
||||
first_name: str | None = None
|
||||
email: str | None = None
|
||||
type: str | None = None # Token type (access/refresh)
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
@@ -108,10 +110,11 @@ class TokenData(BaseModel):
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
"""Schema for changing password (requires current password)."""
|
||||
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -120,10 +123,11 @@ class PasswordChange(BaseModel):
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
"""Schema for resetting password (via email token)."""
|
||||
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -141,23 +145,19 @@ class RefreshTokenRequest(BaseModel):
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
"""Schema for requesting a password reset."""
|
||||
|
||||
email: EmailStr = Field(..., description="Email address of the account")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"email": "user@example.com"
|
||||
}
|
||||
}
|
||||
}
|
||||
model_config = {"json_schema_extra": {"example": {"email": "user@example.com"}}}
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
"""Schema for confirming a password reset with token."""
|
||||
|
||||
token: str = Field(..., description="Password reset token from email")
|
||||
new_password: str = Field(..., min_length=8, description="New password")
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -167,7 +167,7 @@ class PasswordResetConfirm(BaseModel):
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"token": "eyJwYXlsb2FkIjp7ImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTcxMjM0NTY3OH19",
|
||||
"new_password": "NewSecurePassword123"
|
||||
"new_password": "NewSecurePassword123",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,19 +4,34 @@ Shared validators for Pydantic schemas.
|
||||
This module provides reusable validation functions to ensure consistency
|
||||
across all schemas and avoid code duplication.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
# Common weak passwords that should be rejected
|
||||
COMMON_PASSWORDS: Set[str] = {
|
||||
'password', 'password1', 'password123', 'password1234',
|
||||
'admin', 'admin123', 'admin1234',
|
||||
'welcome', 'welcome1', 'welcome123',
|
||||
'qwerty', 'qwerty123',
|
||||
'12345678', '123456789', '1234567890',
|
||||
'letmein', 'letmein1', 'letmein123',
|
||||
'monkey123', 'dragon123',
|
||||
'passw0rd', 'p@ssw0rd', 'p@ssword',
|
||||
COMMON_PASSWORDS: set[str] = {
|
||||
"password",
|
||||
"password1",
|
||||
"password123",
|
||||
"password1234",
|
||||
"admin",
|
||||
"admin123",
|
||||
"admin1234",
|
||||
"welcome",
|
||||
"welcome1",
|
||||
"welcome123",
|
||||
"qwerty",
|
||||
"qwerty123",
|
||||
"12345678",
|
||||
"123456789",
|
||||
"1234567890",
|
||||
"letmein",
|
||||
"letmein1",
|
||||
"letmein123",
|
||||
"monkey123",
|
||||
"dragon123",
|
||||
"passw0rd",
|
||||
"p@ssw0rd",
|
||||
"p@ssword",
|
||||
}
|
||||
|
||||
|
||||
@@ -47,18 +62,21 @@ def validate_password_strength(password: str) -> str:
|
||||
"""
|
||||
# Check minimum length
|
||||
if len(password) < 12:
|
||||
raise ValueError('Password must be at least 12 characters long')
|
||||
raise ValueError("Password must be at least 12 characters long")
|
||||
|
||||
# Check against common passwords (case-insensitive)
|
||||
if password.lower() in COMMON_PASSWORDS:
|
||||
raise ValueError('Password is too common. Please choose a stronger password')
|
||||
raise ValueError("Password is too common. Please choose a stronger password")
|
||||
|
||||
# Check for required character types
|
||||
checks = [
|
||||
(any(c.islower() for c in password), 'at least one lowercase letter'),
|
||||
(any(c.isupper() for c in password), 'at least one uppercase letter'),
|
||||
(any(c.isdigit() for c in password), 'at least one digit'),
|
||||
(any(c in '!@#$%^&*()_+-=[]{}|;:,.<>?~`' for c in password), 'at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)')
|
||||
(any(c.islower() for c in password), "at least one lowercase letter"),
|
||||
(any(c.isupper() for c in password), "at least one uppercase letter"),
|
||||
(any(c.isdigit() for c in password), "at least one digit"),
|
||||
(
|
||||
any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?~`" for c in password),
|
||||
"at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)",
|
||||
),
|
||||
]
|
||||
|
||||
failed = [msg for check, msg in checks if not check]
|
||||
@@ -94,10 +112,10 @@ def validate_phone_number(phone: str | None) -> str | None:
|
||||
|
||||
# Check for empty strings
|
||||
if not phone or phone.strip() == "":
|
||||
raise ValueError('Phone number cannot be empty')
|
||||
raise ValueError("Phone number cannot be empty")
|
||||
|
||||
# Remove all spaces and formatting characters
|
||||
cleaned = re.sub(r'[\s\-\(\)]', '', phone)
|
||||
cleaned = re.sub(r"[\s\-\(\)]", "", phone)
|
||||
|
||||
# Basic pattern:
|
||||
# Must start with + or 0
|
||||
@@ -105,19 +123,19 @@ def validate_phone_number(phone: str | None) -> str | None:
|
||||
# After 0 must have at least 8 digits
|
||||
# Maximum total length of 15 digits (international standard)
|
||||
# Only allowed characters are + at start and digits
|
||||
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
|
||||
pattern = r"^(?:\+[0-9]{8,14}|0[0-9]{8,14})$"
|
||||
|
||||
if not re.match(pattern, cleaned):
|
||||
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
|
||||
raise ValueError("Phone number must start with + or 0 followed by 8-14 digits")
|
||||
|
||||
# Additional validation to catch specific invalid cases
|
||||
# NOTE: These checks are defensive code - the regex pattern above already catches these cases
|
||||
if cleaned.count('+') > 1: # pragma: no cover
|
||||
raise ValueError('Phone number can only contain one + symbol at the start')
|
||||
if cleaned.count("+") > 1: # pragma: no cover
|
||||
raise ValueError("Phone number can only contain one + symbol at the start")
|
||||
|
||||
# Check for any non-digit characters (except the leading +)
|
||||
if not all(c.isdigit() for c in cleaned[1:]): # pragma: no cover
|
||||
raise ValueError('Phone number can only contain digits after the prefix')
|
||||
raise ValueError("Phone number can only contain digits after the prefix")
|
||||
|
||||
return cleaned
|
||||
|
||||
@@ -169,16 +187,16 @@ def validate_slug(slug: str) -> str:
|
||||
ValueError: If slug format is invalid
|
||||
"""
|
||||
if not slug or len(slug) < 2:
|
||||
raise ValueError('Slug must be at least 2 characters long')
|
||||
raise ValueError("Slug must be at least 2 characters long")
|
||||
|
||||
if len(slug) > 50:
|
||||
raise ValueError('Slug must be at most 50 characters long')
|
||||
raise ValueError("Slug must be at most 50 characters long")
|
||||
|
||||
# Check format
|
||||
if not re.match(r'^[a-z0-9]+(?:-[a-z0-9]+)*$', slug):
|
||||
if not re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", slug):
|
||||
raise ValueError(
|
||||
'Slug can only contain lowercase letters, numbers, and hyphens. '
|
||||
'It cannot start or end with a hyphen, and cannot contain consecutive hyphens'
|
||||
"Slug can only contain lowercase letters, numbers, and hyphens. "
|
||||
"It cannot start or end with a hyphen, and cannot contain consecutive hyphens"
|
||||
)
|
||||
|
||||
return slug
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
# app/services/auth_service.py
|
||||
import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password_async,
|
||||
get_password_hash_async,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError
|
||||
get_password_hash_async,
|
||||
verify_password_async,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError
|
||||
@@ -26,7 +25,9 @@ class AuthService:
|
||||
"""Service for handling authentication operations"""
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
|
||||
async def authenticate_user(
|
||||
db: AsyncSession, email: str, password: str
|
||||
) -> User | None:
|
||||
"""
|
||||
Authenticate a user with email and password using async password verification.
|
||||
|
||||
@@ -87,7 +88,7 @@ class AuthService:
|
||||
last_name=user_data.last_name,
|
||||
phone_number=user_data.phone_number,
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
@@ -103,8 +104,8 @@ class AuthService:
|
||||
except Exception as e:
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating user: {str(e)}", exc_info=True)
|
||||
raise AuthenticationError(f"Failed to create user: {str(e)}")
|
||||
logger.error(f"Error creating user: {e!s}", exc_info=True)
|
||||
raise AuthenticationError(f"Failed to create user: {e!s}")
|
||||
|
||||
@staticmethod
|
||||
def create_tokens(user: User) -> Token:
|
||||
@@ -121,18 +122,13 @@ class AuthService:
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name
|
||||
"first_name": user.first_name,
|
||||
}
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(
|
||||
subject=str(user.id),
|
||||
claims=claims
|
||||
)
|
||||
access_token = create_access_token(subject=str(user.id), claims=claims)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
subject=str(user.id)
|
||||
)
|
||||
refresh_token = create_refresh_token(subject=str(user.id))
|
||||
|
||||
# Convert User model to UserResponse schema
|
||||
user_response = UserResponse.model_validate(user)
|
||||
@@ -141,7 +137,8 @@ class AuthService:
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=user_response,
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # Convert minutes to seconds
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
* 60, # Convert minutes to seconds
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -180,11 +177,13 @@ class AuthService:
|
||||
return AuthService.create_tokens(user)
|
||||
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
logger.warning(f"Token refresh failed: {str(e)}")
|
||||
logger.warning(f"Token refresh failed: {e!s}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||
async def change_password(
|
||||
db: AsyncSession, user_id: UUID, current_password: str, new_password: str
|
||||
) -> bool:
|
||||
"""
|
||||
Change a user's password.
|
||||
|
||||
@@ -223,5 +222,7 @@ class AuthService:
|
||||
except Exception as e:
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True)
|
||||
raise AuthenticationError(f"Failed to change password: {str(e)}")
|
||||
logger.error(
|
||||
f"Error changing password for user {user_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise AuthenticationError(f"Failed to change password: {e!s}")
|
||||
|
||||
@@ -5,9 +5,9 @@ Email service with placeholder implementation.
|
||||
This service provides email sending functionality with a simple console/log-based
|
||||
placeholder that can be easily replaced with a real email provider (SendGrid, SES, etc.)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -20,13 +20,12 @@ class EmailBackend(ABC):
|
||||
@abstractmethod
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""Send an email."""
|
||||
pass
|
||||
|
||||
|
||||
class ConsoleEmailBackend(EmailBackend):
|
||||
@@ -39,10 +38,10 @@ class ConsoleEmailBackend(EmailBackend):
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Log email content to console/logs.
|
||||
@@ -88,10 +87,10 @@ class SMTPEmailBackend(EmailBackend):
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""Send email via SMTP."""
|
||||
# TODO: Implement SMTP sending
|
||||
@@ -108,7 +107,7 @@ class EmailService:
|
||||
and can be configured to use different backends (console, SMTP, SendGrid, etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, backend: Optional[EmailBackend] = None):
|
||||
def __init__(self, backend: EmailBackend | None = None):
|
||||
"""
|
||||
Initialize email service with a backend.
|
||||
|
||||
@@ -118,10 +117,7 @@ class EmailService:
|
||||
self.backend = backend or ConsoleEmailBackend()
|
||||
|
||||
async def send_password_reset_email(
|
||||
self,
|
||||
to_email: str,
|
||||
reset_token: str,
|
||||
user_name: Optional[str] = None
|
||||
self, to_email: str, reset_token: str, user_name: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send password reset email.
|
||||
@@ -142,7 +138,7 @@ class EmailService:
|
||||
|
||||
# Plain text version
|
||||
text_content = f"""
|
||||
Hello{' ' + user_name if user_name else ''},
|
||||
Hello{" " + user_name if user_name else ""},
|
||||
|
||||
You requested a password reset for your account. Click the link below to reset your password:
|
||||
|
||||
@@ -177,7 +173,7 @@ The {settings.PROJECT_NAME} Team
|
||||
<h1>Password Reset</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hello{' ' + user_name if user_name else ''},</p>
|
||||
<p>Hello{" " + user_name if user_name else ""},</p>
|
||||
<p>You requested a password reset for your account. Click the button below to reset your password:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{reset_url}" class="button">Reset Password</a>
|
||||
@@ -200,17 +196,14 @@ The {settings.PROJECT_NAME} Team
|
||||
to=[to_email],
|
||||
subject=subject,
|
||||
html_content=html_content,
|
||||
text_content=text_content
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password reset email to {to_email}: {str(e)}")
|
||||
logger.error(f"Failed to send password reset email to {to_email}: {e!s}")
|
||||
return False
|
||||
|
||||
async def send_email_verification(
|
||||
self,
|
||||
to_email: str,
|
||||
verification_token: str,
|
||||
user_name: Optional[str] = None
|
||||
self, to_email: str, verification_token: str, user_name: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send email verification email.
|
||||
@@ -224,14 +217,16 @@ The {settings.PROJECT_NAME} Team
|
||||
True if email sent successfully
|
||||
"""
|
||||
# Generate verification URL
|
||||
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
|
||||
verification_url = (
|
||||
f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
|
||||
)
|
||||
|
||||
# Prepare email content
|
||||
subject = "Verify Your Email Address"
|
||||
|
||||
# Plain text version
|
||||
text_content = f"""
|
||||
Hello{' ' + user_name if user_name else ''},
|
||||
Hello{" " + user_name if user_name else ""},
|
||||
|
||||
Thank you for signing up! Please verify your email address by clicking the link below:
|
||||
|
||||
@@ -266,7 +261,7 @@ The {settings.PROJECT_NAME} Team
|
||||
<h1>Verify Your Email</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hello{' ' + user_name if user_name else ''},</p>
|
||||
<p>Hello{" " + user_name if user_name else ""},</p>
|
||||
<p>Thank you for signing up! Please verify your email address by clicking the button below:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{verification_url}" class="button">Verify Email</a>
|
||||
@@ -289,10 +284,10 @@ The {settings.PROJECT_NAME} Team
|
||||
to=[to_email],
|
||||
subject=subject,
|
||||
html_content=html_content,
|
||||
text_content=text_content
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send verification email to {to_email}: {str(e)}")
|
||||
logger.error(f"Failed to send verification email to {to_email}: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,9 @@ Background job for cleaning up expired sessions.
|
||||
|
||||
This service runs periodically to remove old session records from the database.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.crud.session import session as session_crud
|
||||
@@ -39,7 +40,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error during session cleanup: {e!s}", exc_info=True)
|
||||
return 0
|
||||
|
||||
|
||||
@@ -52,20 +53,21 @@ async def get_session_statistics() -> dict:
|
||||
"""
|
||||
async with SessionLocal() as db:
|
||||
try:
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
total_result = await db.execute(select(func.count(UserSession.id)))
|
||||
total_sessions = total_result.scalar_one()
|
||||
|
||||
active_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(UserSession.is_active == True)
|
||||
select(func.count(UserSession.id)).where(UserSession.is_active)
|
||||
)
|
||||
active_sessions = active_result.scalar_one()
|
||||
|
||||
expired_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
UserSession.expires_at < datetime.now(timezone.utc)
|
||||
UserSession.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
expired_sessions = expired_result.scalar_one()
|
||||
@@ -82,5 +84,5 @@ async def get_session_statistics() -> dict:
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting session statistics: {e!s}", exc_info=True)
|
||||
return {}
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
Authentication utilities for testing.
|
||||
This module provides tools to bypass FastAPI's authentication in tests.
|
||||
"""
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
@@ -13,9 +14,9 @@ from app.models.user import User
|
||||
|
||||
|
||||
def create_test_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User,
|
||||
extra_overrides: Optional[Dict[Callable, Callable]] = None
|
||||
app: FastAPI,
|
||||
test_user: User,
|
||||
extra_overrides: dict[Callable, Callable] | None = None,
|
||||
) -> TestClient:
|
||||
"""
|
||||
Create a test client with authentication pre-configured.
|
||||
@@ -47,10 +48,7 @@ def create_test_auth_client(
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_optional_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
def create_test_optional_auth_client(app: FastAPI, test_user: User) -> TestClient:
|
||||
"""
|
||||
Create a test client with optional authentication pre-configured.
|
||||
|
||||
@@ -70,10 +68,7 @@ def create_test_optional_auth_client(
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_superuser_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
def create_test_superuser_client(app: FastAPI, test_user: User) -> TestClient:
|
||||
"""
|
||||
Create a test client with superuser authentication pre-configured.
|
||||
|
||||
@@ -120,7 +115,7 @@ def cleanup_test_client_auth(app: FastAPI) -> None:
|
||||
auth_deps = [
|
||||
get_current_user,
|
||||
get_optional_current_user,
|
||||
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login"),
|
||||
]
|
||||
|
||||
# Remove overrides
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Utility functions for extracting and parsing device information from HTTP requests.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
@@ -19,11 +19,11 @@ def extract_device_info(request: Request) -> DeviceInfo:
|
||||
Returns:
|
||||
DeviceInfo object with parsed device information
|
||||
"""
|
||||
user_agent = request.headers.get('user-agent', '')
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
|
||||
device_info = DeviceInfo(
|
||||
device_name=parse_device_name(user_agent),
|
||||
device_id=request.headers.get('x-device-id'), # Client must send this header
|
||||
device_id=request.headers.get("x-device-id"), # Client must send this header
|
||||
ip_address=get_client_ip(request),
|
||||
user_agent=user_agent[:500] if user_agent else None, # Truncate to max length
|
||||
location_city=None, # Can be populated via IP geolocation service
|
||||
@@ -33,7 +33,7 @@ def extract_device_info(request: Request) -> DeviceInfo:
|
||||
return device_info
|
||||
|
||||
|
||||
def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
def parse_device_name(user_agent: str) -> str | None:
|
||||
"""
|
||||
Parse user agent string to extract a friendly device name.
|
||||
|
||||
@@ -54,48 +54,48 @@ def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Mobile devices (check first, as they can contain desktop patterns too)
|
||||
if 'iphone' in user_agent_lower:
|
||||
if "iphone" in user_agent_lower:
|
||||
return "iPhone"
|
||||
elif 'ipad' in user_agent_lower:
|
||||
elif "ipad" in user_agent_lower:
|
||||
return "iPad"
|
||||
elif 'android' in user_agent_lower:
|
||||
elif "android" in user_agent_lower:
|
||||
# Try to extract device model
|
||||
android_match = re.search(r'android.*;\s*([^)]+)\s*build', user_agent_lower)
|
||||
android_match = re.search(r"android.*;\s*([^)]+)\s*build", user_agent_lower)
|
||||
if android_match:
|
||||
device_model = android_match.group(1).strip()
|
||||
return f"Android ({device_model.title()})"
|
||||
return "Android device"
|
||||
elif 'windows phone' in user_agent_lower:
|
||||
elif "windows phone" in user_agent_lower:
|
||||
return "Windows Phone"
|
||||
|
||||
# Tablets (check before desktop, as some tablets contain "android")
|
||||
elif 'tablet' in user_agent_lower:
|
||||
elif "tablet" in user_agent_lower:
|
||||
return "Tablet"
|
||||
|
||||
# Smart TVs (check before desktop OS patterns)
|
||||
elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv']):
|
||||
elif any(tv in user_agent_lower for tv in ["smart-tv", "smarttv"]):
|
||||
return "Smart TV"
|
||||
|
||||
# Game consoles (check before desktop OS patterns, as Xbox contains "Windows")
|
||||
elif 'playstation' in user_agent_lower:
|
||||
elif "playstation" in user_agent_lower:
|
||||
return "PlayStation"
|
||||
elif 'xbox' in user_agent_lower:
|
||||
elif "xbox" in user_agent_lower:
|
||||
return "Xbox"
|
||||
elif 'nintendo' in user_agent_lower:
|
||||
elif "nintendo" in user_agent_lower:
|
||||
return "Nintendo"
|
||||
|
||||
# Desktop operating systems
|
||||
elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower:
|
||||
elif "macintosh" in user_agent_lower or "mac os x" in user_agent_lower:
|
||||
# Try to extract browser
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Mac" if browser else "Mac"
|
||||
elif 'windows' in user_agent_lower:
|
||||
elif "windows" in user_agent_lower:
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Windows" if browser else "Windows PC"
|
||||
elif 'linux' in user_agent_lower and 'android' not in user_agent_lower:
|
||||
elif "linux" in user_agent_lower and "android" not in user_agent_lower:
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Linux" if browser else "Linux"
|
||||
elif 'cros' in user_agent_lower:
|
||||
elif "cros" in user_agent_lower:
|
||||
return "Chromebook"
|
||||
|
||||
# Fallback: just return browser name if detected
|
||||
@@ -106,7 +106,7 @@ def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
return "Unknown device"
|
||||
|
||||
|
||||
def extract_browser(user_agent: str) -> Optional[str]:
|
||||
def extract_browser(user_agent: str) -> str | None:
|
||||
"""
|
||||
Extract browser name from user agent string.
|
||||
|
||||
@@ -126,26 +126,26 @@ def extract_browser(user_agent: str) -> Optional[str]:
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Check specific browsers (order matters - check Edge before Chrome!)
|
||||
if 'edg/' in user_agent_lower or 'edge/' in user_agent_lower:
|
||||
if "edg/" in user_agent_lower or "edge/" in user_agent_lower:
|
||||
return "Edge"
|
||||
elif 'opr/' in user_agent_lower or 'opera' in user_agent_lower:
|
||||
elif "opr/" in user_agent_lower or "opera" in user_agent_lower:
|
||||
return "Opera"
|
||||
elif 'chrome/' in user_agent_lower:
|
||||
elif "chrome/" in user_agent_lower:
|
||||
return "Chrome"
|
||||
elif 'safari/' in user_agent_lower:
|
||||
elif "safari/" in user_agent_lower:
|
||||
# Make sure it's actually Safari, not Chrome (which also contains "Safari")
|
||||
if 'chrome' not in user_agent_lower:
|
||||
if "chrome" not in user_agent_lower:
|
||||
return "Safari"
|
||||
return None
|
||||
elif 'firefox/' in user_agent_lower:
|
||||
elif "firefox/" in user_agent_lower:
|
||||
return "Firefox"
|
||||
elif 'msie' in user_agent_lower or 'trident/' in user_agent_lower:
|
||||
elif "msie" in user_agent_lower or "trident/" in user_agent_lower:
|
||||
return "Internet Explorer"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> Optional[str]:
|
||||
def get_client_ip(request: Request) -> str | None:
|
||||
"""
|
||||
Extract client IP address from request, considering proxy headers.
|
||||
|
||||
@@ -163,14 +163,14 @@ def get_client_ip(request: Request) -> Optional[str]:
|
||||
- request.client.host is fallback for direct connections
|
||||
"""
|
||||
# Check X-Forwarded-For (common in proxied environments)
|
||||
x_forwarded_for = request.headers.get('x-forwarded-for')
|
||||
x_forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if x_forwarded_for:
|
||||
# Get the first IP (original client)
|
||||
client_ip = x_forwarded_for.split(',')[0].strip()
|
||||
client_ip = x_forwarded_for.split(",")[0].strip()
|
||||
return client_ip
|
||||
|
||||
# Check X-Real-IP (used by some proxies like nginx)
|
||||
x_real_ip = request.headers.get('x-real-ip')
|
||||
x_real_ip = request.headers.get("x-real-ip")
|
||||
if x_real_ip:
|
||||
return x_real_ip.strip()
|
||||
|
||||
@@ -195,9 +195,17 @@ def is_mobile_device(user_agent: str) -> bool:
|
||||
return False
|
||||
|
||||
mobile_patterns = [
|
||||
'mobile', 'android', 'iphone', 'ipad', 'ipod',
|
||||
'blackberry', 'windows phone', 'webos', 'opera mini',
|
||||
'iemobile', 'mobile safari'
|
||||
"mobile",
|
||||
"android",
|
||||
"iphone",
|
||||
"ipad",
|
||||
"ipod",
|
||||
"blackberry",
|
||||
"windows phone",
|
||||
"webos",
|
||||
"opera mini",
|
||||
"iemobile",
|
||||
"mobile safari",
|
||||
]
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
@@ -220,7 +228,7 @@ def get_device_type(user_agent: str) -> str:
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Check for tablets first (they can contain "mobile" too)
|
||||
if 'ipad' in user_agent_lower or 'tablet' in user_agent_lower:
|
||||
if "ipad" in user_agent_lower or "tablet" in user_agent_lower:
|
||||
return "tablet"
|
||||
|
||||
# Check for mobile
|
||||
@@ -228,7 +236,7 @@ def get_device_type(user_agent: str) -> str:
|
||||
return "mobile"
|
||||
|
||||
# Check for desktop OS patterns
|
||||
if any(os in user_agent_lower for os in ['windows', 'macintosh', 'linux', 'cros']):
|
||||
if any(os in user_agent_lower for os in ["windows", "macintosh", "linux", "cros"]):
|
||||
return "desktop"
|
||||
|
||||
return "other"
|
||||
|
||||
@@ -5,18 +5,21 @@ This module provides utilities for creating and verifying signed tokens,
|
||||
useful for operations like file uploads, password resets, or any other
|
||||
time-limited, single-use operations.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def create_upload_token(file_path: str, content_type: str, expires_in: int = 300) -> str:
|
||||
def create_upload_token(
|
||||
file_path: str, content_type: str, expires_in: int = 300
|
||||
) -> str:
|
||||
"""
|
||||
Create a signed token for secure file uploads.
|
||||
|
||||
@@ -40,34 +43,29 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300
|
||||
"path": file_path,
|
||||
"content_type": content_type,
|
||||
"exp": int(time.time()) + expires_in,
|
||||
"nonce": secrets.token_hex(8) # Add randomness to prevent token reuse
|
||||
"nonce": secrets.token_hex(8), # Add randomness to prevent token reuse
|
||||
}
|
||||
|
||||
# Convert to JSON and encode
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Create a signature using HMAC-SHA256 for security
|
||||
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||
signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Combine payload and signature
|
||||
token_data = {
|
||||
"payload": payload,
|
||||
"signature": signature
|
||||
}
|
||||
token_data = {"payload": payload, "signature": signature}
|
||||
|
||||
# Encode the final token
|
||||
token_json = json.dumps(token_data)
|
||||
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
|
||||
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
def verify_upload_token(token: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Verify an upload token and return the payload if valid.
|
||||
|
||||
@@ -88,7 +86,7 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
try:
|
||||
# Decode the token
|
||||
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
# Extract payload and signature
|
||||
@@ -96,11 +94,9 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
signature = token_data["signature"]
|
||||
|
||||
# Verify signature using HMAC and constant-time comparison
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
@@ -136,34 +132,29 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str:
|
||||
"email": email,
|
||||
"exp": int(time.time()) + expires_in,
|
||||
"nonce": secrets.token_hex(16), # Extra randomness
|
||||
"purpose": "password_reset"
|
||||
"purpose": "password_reset",
|
||||
}
|
||||
|
||||
# Convert to JSON and encode
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Create a signature using HMAC-SHA256 for security
|
||||
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||
signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Combine payload and signature
|
||||
token_data = {
|
||||
"payload": payload,
|
||||
"signature": signature
|
||||
}
|
||||
token_data = {"payload": payload, "signature": signature}
|
||||
|
||||
# Encode the final token
|
||||
token_json = json.dumps(token_data)
|
||||
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
|
||||
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
def verify_password_reset_token(token: str) -> str | None:
|
||||
"""
|
||||
Verify a password reset token and return the email if valid.
|
||||
|
||||
@@ -182,7 +173,7 @@ def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
try:
|
||||
# Decode the token
|
||||
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
# Extract payload and signature
|
||||
@@ -194,11 +185,9 @@ def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
# Verify signature using HMAC and constant-time comparison
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
@@ -234,34 +223,29 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str:
|
||||
"email": email,
|
||||
"exp": int(time.time()) + expires_in,
|
||||
"nonce": secrets.token_hex(16),
|
||||
"purpose": "email_verification"
|
||||
"purpose": "email_verification",
|
||||
}
|
||||
|
||||
# Convert to JSON and encode
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Create a signature using HMAC-SHA256 for security
|
||||
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||
signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Combine payload and signature
|
||||
token_data = {
|
||||
"payload": payload,
|
||||
"signature": signature
|
||||
}
|
||||
token_data = {"payload": payload, "signature": signature}
|
||||
|
||||
# Encode the final token
|
||||
token_json = json.dumps(token_data)
|
||||
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
|
||||
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def verify_email_verification_token(token: str) -> Optional[str]:
|
||||
def verify_email_verification_token(token: str) -> str | None:
|
||||
"""
|
||||
Verify an email verification token and return the email if valid.
|
||||
|
||||
@@ -280,7 +264,7 @@ def verify_email_verification_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
try:
|
||||
# Decode the token
|
||||
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
# Extract payload and signature
|
||||
@@ -292,11 +276,9 @@ def verify_email_verification_token(token: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
# Verify signature using HMAC and constant-time comparison
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
|
||||
@@ -9,17 +9,19 @@ from app.core.database import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_test_engine():
|
||||
"""Create an SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
echo=False,
|
||||
)
|
||||
|
||||
return test_engine
|
||||
|
||||
|
||||
def setup_test_db():
|
||||
"""Create a test database and session factory"""
|
||||
# Create a new engine for this test run
|
||||
@@ -30,14 +32,12 @@ def setup_test_db():
|
||||
|
||||
# Create session factory
|
||||
TestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False
|
||||
autocommit=False, autoflush=False, bind=test_engine, expire_on_commit=False
|
||||
)
|
||||
|
||||
return test_engine, TestingSessionLocal
|
||||
|
||||
|
||||
def teardown_test_db(engine):
|
||||
"""Clean up after tests"""
|
||||
# Drop all tables
|
||||
@@ -46,13 +46,14 @@ def teardown_test_db(engine):
|
||||
# Dispose of engine
|
||||
engine.dispose()
|
||||
|
||||
|
||||
async def get_async_test_engine():
|
||||
"""Create an async SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
echo=False,
|
||||
)
|
||||
return test_engine
|
||||
|
||||
@@ -69,7 +70,7 @@ async def setup_async_test_db():
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession
|
||||
class_=AsyncSession,
|
||||
)
|
||||
|
||||
return test_engine, AsyncTestingSessionLocal
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# tests/api/dependencies/test_auth_dependencies.py
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.api.dependencies.auth import (
|
||||
get_current_user,
|
||||
get_current_active_user,
|
||||
get_current_superuser,
|
||||
get_optional_current_user
|
||||
get_current_user,
|
||||
get_optional_current_user,
|
||||
)
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
|
||||
from app.models.user import User
|
||||
@@ -24,7 +25,7 @@ def mock_token():
|
||||
@pytest_asyncio.fixture
|
||||
async def async_mock_user(async_test_db):
|
||||
"""Async fixture to create and return a mock User instance."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
|
||||
"""Tests for get_current_user dependency"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_current_user_success(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test successfully getting the current user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to return user_id that matches our mock_user
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
|
||||
"""Test when the token contains a user ID that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to return a non-existent user ID
|
||||
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = nonexistent_id
|
||||
|
||||
# Should raise HTTPException with 404 status
|
||||
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
|
||||
assert "User not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_current_user_inactive(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test when the user is inactive"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
from sqlalchemy import select
|
||||
result = await session.execute(select(User).where(User.id == async_mock_user.id))
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_mock_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Should raise HTTPException with 403 status
|
||||
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_expired_token(self, async_test_db, mock_token):
|
||||
"""Test with an expired token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
|
||||
"""Test with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
|
||||
"""Tests for get_optional_current_user dependency"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_optional_current_user_with_token(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test getting optional user with a valid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_no_token(self, async_test_db):
|
||||
"""Test getting optional user with no token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Call the dependency with no token
|
||||
user = await get_optional_current_user(db=session, token=None)
|
||||
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token):
|
||||
async def test_get_optional_current_user_invalid_token(
|
||||
self, async_test_db, mock_token
|
||||
):
|
||||
"""Test getting optional user with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Call the dependency
|
||||
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token):
|
||||
async def test_get_optional_current_user_expired_token(
|
||||
self, async_test_db, mock_token
|
||||
):
|
||||
"""Test getting optional user with an expired token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Call the dependency
|
||||
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_optional_current_user_inactive(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test getting optional user when user is inactive"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
from sqlalchemy import select
|
||||
result = await session.execute(select(User).where(User.id == async_mock_user.id))
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_mock_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
# tests/api/routes/test_health.py
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import get_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -121,7 +120,10 @@ class TestHealthEndpoint:
|
||||
response = client.get("/health")
|
||||
|
||||
# Should succeed without authentication
|
||||
assert response.status_code in [status.HTTP_200_OK, status.HTTP_503_SERVICE_UNAVAILABLE]
|
||||
assert response.status_code in [
|
||||
status.HTTP_200_OK,
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
]
|
||||
|
||||
def test_health_check_idempotent(self, client):
|
||||
"""Test that multiple health checks return consistent results"""
|
||||
@@ -142,7 +144,10 @@ class TestHealthEndpoint:
|
||||
assert data1["environment"] == data2["environment"]
|
||||
|
||||
# Same database check status
|
||||
assert data1["checks"]["database"]["status"] == data2["checks"]["database"]["status"]
|
||||
assert (
|
||||
data1["checks"]["database"]["status"]
|
||||
== data2["checks"]["database"]["status"]
|
||||
)
|
||||
|
||||
def test_health_check_content_type(self, client):
|
||||
"""Test that health check returns JSON content type"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
"""
|
||||
Tests for authentication endpoints.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
@@ -19,8 +20,8 @@ class TestRegisterEndpoint:
|
||||
"email": "newuser@example.com",
|
||||
"password": "NewPassword123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
@@ -36,8 +37,8 @@ class TestRegisterEndpoint:
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!",
|
||||
"first_name": "Test",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -51,8 +52,8 @@ class TestRegisterEndpoint:
|
||||
"email": "test@example.com",
|
||||
"password": "weak",
|
||||
"first_name": "Test",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -66,10 +67,7 @@ class TestLoginEndpoint:
|
||||
"""Test successful login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -82,10 +80,7 @@ class TestLoginEndpoint:
|
||||
"""Test login with invalid password."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "WrongPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "WrongPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -95,10 +90,7 @@ class TestLoginEndpoint:
|
||||
"""Test login with non-existent user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "nonexistent@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "nonexistent@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -106,27 +98,25 @@ class TestLoginEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_inactive_user(self, client, async_test_db):
|
||||
"""Test login with inactive user."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user import User
|
||||
|
||||
inactive_user = User(
|
||||
email="inactive@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="Inactive",
|
||||
last_name="User",
|
||||
is_active=False
|
||||
is_active=False,
|
||||
)
|
||||
session.add(inactive_user)
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "inactive@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "inactive@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -140,10 +130,7 @@ class TestRefreshTokenEndpoint:
|
||||
"""Get a refresh token for testing."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
return response.json()["refresh_token"]
|
||||
|
||||
@@ -151,8 +138,7 @@ class TestRefreshTokenEndpoint:
|
||||
async def test_refresh_token_success(self, client, refresh_token):
|
||||
"""Test successful token refresh."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -164,8 +150,7 @@ class TestRefreshTokenEndpoint:
|
||||
async def test_refresh_token_invalid(self, client):
|
||||
"""Test refresh with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid.token.here"}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": "invalid.token.here"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -179,13 +164,13 @@ class TestLogoutEndpoint:
|
||||
"""Get tokens for testing."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
data = response.json()
|
||||
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
|
||||
return {
|
||||
"access_token": data["access_token"],
|
||||
"refresh_token": data["refresh_token"],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_success(self, client, tokens):
|
||||
@@ -193,7 +178,7 @@ class TestLogoutEndpoint:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -202,8 +187,7 @@ class TestLogoutEndpoint:
|
||||
async def test_logout_without_auth(self, client):
|
||||
"""Test logout without authentication."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
json={"refresh_token": "some.token"}
|
||||
"/api/v1/auth/logout", json={"refresh_token": "some.token"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -215,8 +199,7 @@ class TestPasswordResetRequest:
|
||||
async def test_password_reset_request_success(self, client, async_test_user):
|
||||
"""Test password reset request with existing user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
"/api/v1/auth/password-reset/request", json={"email": async_test_user.email}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -228,7 +211,7 @@ class TestPasswordResetRequest:
|
||||
"""Test password reset request with non-existent email."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": "nonexistent@example.com"}
|
||||
json={"email": "nonexistent@example.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -244,10 +227,7 @@ class TestPasswordResetConfirm:
|
||||
"""Test password reset with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": "invalid.token.here",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
json={"token": "invalid.token.here", "new_password": "NewPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -261,20 +241,20 @@ class TestLogoutAll:
|
||||
"""Get tokens for testing."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
data = response.json()
|
||||
return {"access_token": data["access_token"], "refresh_token": data["refresh_token"]}
|
||||
return {
|
||||
"access_token": data["access_token"],
|
||||
"refresh_token": data["refresh_token"],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_all_success(self, client, tokens):
|
||||
"""Test logout from all devices."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"}
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -298,10 +278,7 @@ class TestOAuthLogin:
|
||||
"""Test successful OAuth login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
data={"username": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -315,10 +292,7 @@ class TestOAuthLogin:
|
||||
"""Test OAuth login with invalid credentials."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": "testuser@example.com",
|
||||
"password": "WrongPassword"
|
||||
}
|
||||
data={"username": "testuser@example.com", "password": "WrongPassword"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# tests/api/dependencies/test_auth_dependencies.py
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.api.dependencies.auth import (
|
||||
get_current_user,
|
||||
get_current_active_user,
|
||||
get_current_superuser,
|
||||
get_optional_current_user
|
||||
get_current_user,
|
||||
get_optional_current_user,
|
||||
)
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_password_hash
|
||||
from app.models.user import User
|
||||
@@ -24,7 +25,7 @@ def mock_token():
|
||||
@pytest_asyncio.fixture
|
||||
async def async_mock_user(async_test_db):
|
||||
"""Async fixture to create and return a mock User instance."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
@@ -47,12 +48,14 @@ class TestGetCurrentUser:
|
||||
"""Tests for get_current_user dependency"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_success(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_current_user_success(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test successfully getting the current user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to return user_id that matches our mock_user
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
@@ -65,12 +68,12 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_nonexistent(self, async_test_db, mock_token):
|
||||
"""Test when the token contains a user ID that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to return a non-existent user ID
|
||||
nonexistent_id = uuid.UUID("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = nonexistent_id
|
||||
|
||||
# Should raise HTTPException with 404 status
|
||||
@@ -81,19 +84,24 @@ class TestGetCurrentUser:
|
||||
assert "User not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_current_user_inactive(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test when the user is inactive"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
from sqlalchemy import select
|
||||
result = await session.execute(select(User).where(User.id == async_mock_user.id))
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_mock_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Should raise HTTPException with 403 status
|
||||
@@ -106,10 +114,10 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_expired_token(self, async_test_db, mock_token):
|
||||
"""Test with an expired token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
@@ -122,10 +130,10 @@ class TestGetCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_invalid_token(self, async_test_db, mock_token):
|
||||
"""Test with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Should raise HTTPException with 401 status
|
||||
@@ -194,12 +202,14 @@ class TestGetOptionalCurrentUser:
|
||||
"""Tests for get_optional_current_user dependency"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_with_token(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_optional_current_user_with_token(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test getting optional user with a valid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
@@ -212,7 +222,7 @@ class TestGetOptionalCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_no_token(self, async_test_db):
|
||||
"""Test getting optional user with no token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Call the dependency with no token
|
||||
user = await get_optional_current_user(db=session, token=None)
|
||||
@@ -221,12 +231,14 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_invalid_token(self, async_test_db, mock_token):
|
||||
async def test_get_optional_current_user_invalid_token(
|
||||
self, async_test_db, mock_token
|
||||
):
|
||||
"""Test getting optional user with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenInvalidError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenInvalidError("Invalid token")
|
||||
|
||||
# Call the dependency
|
||||
@@ -236,12 +248,14 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_expired_token(self, async_test_db, mock_token):
|
||||
async def test_get_optional_current_user_expired_token(
|
||||
self, async_test_db, mock_token
|
||||
):
|
||||
"""Test getting optional user with an expired token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock get_token_data to raise TokenExpiredError
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
# Call the dependency
|
||||
@@ -251,19 +265,24 @@ class TestGetOptionalCurrentUser:
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_current_user_inactive(self, async_test_db, async_mock_user, mock_token):
|
||||
async def test_get_optional_current_user_inactive(
|
||||
self, async_test_db, async_mock_user, mock_token
|
||||
):
|
||||
"""Test getting optional user when user is inactive"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
from sqlalchemy import select
|
||||
result = await session.execute(select(User).where(User.id == async_mock_user.id))
|
||||
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_mock_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
# Mock get_token_data
|
||||
with patch('app.api.dependencies.auth.get_token_data') as mock_get_data:
|
||||
with patch("app.api.dependencies.auth.get_token_data") as mock_get_data:
|
||||
mock_get_data.return_value.user_id = async_mock_user.id
|
||||
|
||||
# Call the dependency
|
||||
|
||||
@@ -2,21 +2,21 @@
|
||||
"""
|
||||
Tests for authentication endpoints.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi import status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
# Disable rate limiting for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limit():
|
||||
"""Disable rate limiting for all tests in this module."""
|
||||
with patch('app.api.routes.auth.limiter.enabled', False):
|
||||
with patch("app.api.routes.auth.limiter.enabled", False):
|
||||
yield
|
||||
|
||||
|
||||
@@ -32,8 +32,8 @@ class TestRegisterEndpoint:
|
||||
"email": "newuser@example.com",
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
@@ -54,8 +54,8 @@ class TestRegisterEndpoint:
|
||||
"email": async_test_user.email,
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "Duplicate",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
# Security: Returns 400 with generic message to prevent email enumeration
|
||||
@@ -73,8 +73,8 @@ class TestRegisterEndpoint:
|
||||
"email": "weakpass@example.com",
|
||||
"password": "weak",
|
||||
"first_name": "Weak",
|
||||
"last_name": "Pass"
|
||||
}
|
||||
"last_name": "Pass",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -82,7 +82,7 @@ class TestRegisterEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_unexpected_error(self, client):
|
||||
"""Test registration with unexpected error."""
|
||||
with patch('app.services.auth_service.AuthService.create_user') as mock_create:
|
||||
with patch("app.services.auth_service.AuthService.create_user") as mock_create:
|
||||
mock_create.side_effect = Exception("Unexpected error")
|
||||
|
||||
response = await client.post(
|
||||
@@ -91,8 +91,8 @@ class TestRegisterEndpoint:
|
||||
"email": "error@example.com",
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "Error",
|
||||
"last_name": "User"
|
||||
}
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -106,10 +106,7 @@ class TestLoginEndpoint:
|
||||
"""Test successful login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -123,10 +120,7 @@ class TestLoginEndpoint:
|
||||
"""Test login with wrong password."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "WrongPassword123"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "WrongPassword123"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -136,10 +130,7 @@ class TestLoginEndpoint:
|
||||
"""Test login with non-existent email."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "nonexistent@example.com",
|
||||
"password": "Password123!"
|
||||
}
|
||||
json={"email": "nonexistent@example.com", "password": "Password123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -147,20 +138,19 @@ class TestLoginEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_inactive_user(self, client, async_test_user, async_test_db):
|
||||
"""Test login with inactive user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -168,15 +158,14 @@ class TestLoginEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_unexpected_error(self, client, async_test_user):
|
||||
"""Test login with unexpected error."""
|
||||
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
|
||||
with patch(
|
||||
"app.services.auth_service.AuthService.authenticate_user"
|
||||
) as mock_auth:
|
||||
mock_auth.side_effect = Exception("Database error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -190,10 +179,7 @@ class TestOAuthLoginEndpoint:
|
||||
"""Test successful OAuth login."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
data={"username": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -206,31 +192,29 @@ class TestOAuthLoginEndpoint:
|
||||
"""Test OAuth login with wrong credentials."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "WrongPassword"
|
||||
}
|
||||
data={"username": async_test_user.email, "password": "WrongPassword"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_login_inactive_user(self, client, async_test_user, async_test_db):
|
||||
async def test_oauth_login_inactive_user(
|
||||
self, client, async_test_user, async_test_db
|
||||
):
|
||||
"""Test OAuth login with inactive user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get the user in this session and make it inactive
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
data={"username": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -238,15 +222,17 @@ class TestOAuthLoginEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_login_unexpected_error(self, client, async_test_user):
|
||||
"""Test OAuth login with unexpected error."""
|
||||
with patch('app.services.auth_service.AuthService.authenticate_user') as mock_auth:
|
||||
with patch(
|
||||
"app.services.auth_service.AuthService.authenticate_user"
|
||||
) as mock_auth:
|
||||
mock_auth.side_effect = Exception("Unexpected error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
"password": "TestPassword123!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -261,17 +247,13 @@ class TestRefreshTokenEndpoint:
|
||||
# First, login to get a refresh token
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
refresh_token = login_response.json()["refresh_token"]
|
||||
|
||||
# Now refresh the token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -284,12 +266,13 @@ class TestRefreshTokenEndpoint:
|
||||
"""Test refresh with expired token."""
|
||||
from app.core.auth import TokenExpiredError
|
||||
|
||||
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
|
||||
with patch(
|
||||
"app.services.auth_service.AuthService.refresh_tokens"
|
||||
) as mock_refresh:
|
||||
mock_refresh.side_effect = TokenExpiredError("Token expired")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "some_token"}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": "some_token"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -298,8 +281,7 @@ class TestRefreshTokenEndpoint:
|
||||
async def test_refresh_token_invalid(self, client):
|
||||
"""Test refresh with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid_token"}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": "invalid_token"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -310,19 +292,17 @@ class TestRefreshTokenEndpoint:
|
||||
# Get a valid refresh token first
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "TestPassword123!"},
|
||||
)
|
||||
refresh_token = login_response.json()["refresh_token"]
|
||||
|
||||
with patch('app.services.auth_service.AuthService.refresh_tokens') as mock_refresh:
|
||||
with patch(
|
||||
"app.services.auth_service.AuthService.refresh_tokens"
|
||||
) as mock_refresh:
|
||||
mock_refresh.side_effect = Exception("Unexpected error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
"""
|
||||
Tests for auth route exception handlers and error paths.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from fastapi import status
|
||||
|
||||
|
||||
@@ -11,16 +13,18 @@ class TestLoginSessionCreationFailure:
|
||||
"""Test login when session creation fails."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_succeeds_despite_session_creation_failure(self, client, async_test_user):
|
||||
async def test_login_succeeds_despite_session_creation_failure(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test that login succeeds even if session creation fails."""
|
||||
# Mock session creation to fail
|
||||
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session creation failed")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.create_session",
|
||||
side_effect=Exception("Session creation failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
# Login should still succeed, just without session record
|
||||
@@ -34,15 +38,20 @@ class TestOAuthLoginSessionCreationFailure:
|
||||
"""Test OAuth login when session creation fails."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_login_succeeds_despite_session_failure(self, client, async_test_user):
|
||||
async def test_oauth_login_succeeds_despite_session_failure(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test OAuth login succeeds even if session creation fails."""
|
||||
with patch('app.api.routes.auth.session_crud.create_session', side_effect=Exception("Session failed")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.create_session",
|
||||
side_effect=Exception("Session failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/oauth",
|
||||
data={
|
||||
"username": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
"password": "TestPassword123!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -54,23 +63,24 @@ class TestRefreshTokenSessionUpdateFailure:
|
||||
"""Test refresh token when session update fails."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_succeeds_despite_session_update_failure(self, client, async_test_user):
|
||||
async def test_refresh_token_succeeds_despite_session_update_failure(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test that token refresh succeeds even if session update fails."""
|
||||
# First login to get tokens
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
# Mock session update to fail
|
||||
with patch('app.api.routes.auth.session_crud.update_refresh_token', side_effect=Exception("Update failed")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.update_refresh_token",
|
||||
side_effect=Exception("Update failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": tokens["refresh_token"]}
|
||||
)
|
||||
|
||||
# Should still succeed - tokens are issued before update
|
||||
@@ -83,15 +93,14 @@ class TestLogoutWithExpiredToken:
|
||||
"""Test logout with expired/invalid token."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_with_invalid_token_still_succeeds(self, client, async_test_user):
|
||||
async def test_logout_with_invalid_token_still_succeeds(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test logout succeeds even with invalid refresh token."""
|
||||
# Login first
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
access_token = response.json()["access_token"]
|
||||
|
||||
@@ -99,7 +108,7 @@ class TestLogoutWithExpiredToken:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"refresh_token": "invalid.token.here"}
|
||||
json={"refresh_token": "invalid.token.here"},
|
||||
)
|
||||
|
||||
# Should succeed (idempotent)
|
||||
@@ -116,19 +125,16 @@ class TestLogoutWithNonExistentSession:
|
||||
"""Test logout succeeds even if session not found."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
# Mock session lookup to return None
|
||||
with patch('app.api.routes.auth.session_crud.get_by_jti', return_value=None):
|
||||
with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
|
||||
# Should succeed (idempotent)
|
||||
@@ -139,23 +145,25 @@ class TestLogoutUnexpectedError:
|
||||
"""Test logout with unexpected errors."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_with_unexpected_error_returns_success(self, client, async_test_user):
|
||||
async def test_logout_with_unexpected_error_returns_success(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test logout returns success even on unexpected errors."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
# Mock to raise unexpected error
|
||||
with patch('app.api.routes.auth.session_crud.get_by_jti', side_effect=Exception("Unexpected error")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.get_by_jti",
|
||||
side_effect=Exception("Unexpected error"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
|
||||
# Should still return success (don't expose errors)
|
||||
@@ -172,18 +180,18 @@ class TestLogoutAllUnexpectedError:
|
||||
"""Test logout-all handles database errors."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
access_token = response.json()["access_token"]
|
||||
|
||||
# Mock to raise database error
|
||||
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("DB error")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
|
||||
side_effect=Exception("DB error"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {access_token}"}
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -193,7 +201,9 @@ class TestPasswordResetConfirmSessionInvalidation:
|
||||
"""Test password reset invalidates sessions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_continues_despite_session_invalidation_failure(self, client, async_test_user):
|
||||
async def test_password_reset_continues_despite_session_invalidation_failure(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test password reset succeeds even if session invalidation fails."""
|
||||
# Create a valid password reset token
|
||||
from app.utils.security import create_password_reset_token
|
||||
@@ -201,13 +211,13 @@ class TestPasswordResetConfirmSessionInvalidation:
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
|
||||
# Mock session invalidation to fail
|
||||
with patch('app.api.routes.auth.session_crud.deactivate_all_user_sessions', side_effect=Exception("Invalidation failed")):
|
||||
with patch(
|
||||
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
|
||||
side_effect=Exception("Invalidation failed"),
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewPassword123!"},
|
||||
)
|
||||
|
||||
# Should still succeed - password was reset
|
||||
|
||||
@@ -2,22 +2,22 @@
|
||||
"""
|
||||
Tests for password reset endpoints.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
from fastapi import status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.schemas.users import PasswordResetRequest, PasswordResetConfirm
|
||||
from app.utils.security import create_password_reset_token
|
||||
from app.models.user import User
|
||||
from app.utils.security import create_password_reset_token
|
||||
|
||||
|
||||
# Disable rate limiting for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limit():
|
||||
"""Disable rate limiting for all tests in this module."""
|
||||
with patch('app.api.routes.auth.limiter.enabled', False):
|
||||
with patch("app.api.routes.auth.limiter.enabled", False):
|
||||
yield
|
||||
|
||||
|
||||
@@ -27,12 +27,14 @@ class TestPasswordResetRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_valid_email(self, client, async_test_user):
|
||||
"""Test password reset request with valid email."""
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
mock_send.return_value = True
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -50,10 +52,12 @@ class TestPasswordResetRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_nonexistent_email(self, client):
|
||||
"""Test password reset request with non-existent email."""
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": "nonexistent@example.com"}
|
||||
json={"email": "nonexistent@example.com"},
|
||||
)
|
||||
|
||||
# Should still return success to prevent email enumeration
|
||||
@@ -65,20 +69,26 @@ class TestPasswordResetRequest:
|
||||
mock_send.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_inactive_user(self, client, async_test_db, async_test_user):
|
||||
async def test_password_reset_request_inactive_user(
|
||||
self, client, async_test_db, async_test_user
|
||||
):
|
||||
"""Test password reset request with inactive user."""
|
||||
# Deactivate user
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
|
||||
# Should still return success to prevent email enumeration
|
||||
@@ -93,8 +103,7 @@ class TestPasswordResetRequest:
|
||||
async def test_password_reset_request_invalid_email_format(self, client):
|
||||
"""Test password reset request with invalid email format."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": "not-an-email"}
|
||||
"/api/v1/auth/password-reset/request", json={"email": "not-an-email"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -102,22 +111,23 @@ class TestPasswordResetRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_missing_email(self, client):
|
||||
"""Test password reset request without email."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={}
|
||||
)
|
||||
response = await client.post("/api/v1/auth/password-reset/request", json={})
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_email_service_error(self, client, async_test_user):
|
||||
async def test_password_reset_request_email_service_error(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test password reset when email service fails."""
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
mock_send.side_effect = Exception("SMTP Error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
|
||||
# Should still return success even if email fails
|
||||
@@ -128,14 +138,16 @@ class TestPasswordResetRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_request_rate_limiting(self, client, async_test_user):
|
||||
"""Test that password reset requests are rate limited."""
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
mock_send.return_value = True
|
||||
|
||||
# Make multiple requests quickly (3/minute limit)
|
||||
for _ in range(3):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@@ -144,7 +156,9 @@ class TestPasswordResetConfirm:
|
||||
"""Tests for POST /auth/password-reset/confirm endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_confirm_valid_token(self, client, async_test_user, async_test_db):
|
||||
async def test_password_reset_confirm_valid_token(
|
||||
self, client, async_test_user, async_test_db
|
||||
):
|
||||
"""Test password reset confirmation with valid token."""
|
||||
# Generate valid token
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
@@ -152,10 +166,7 @@ class TestPasswordResetConfirm:
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": new_password
|
||||
}
|
||||
json={"token": token, "new_password": new_password},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -164,11 +175,14 @@ class TestPasswordResetConfirm:
|
||||
assert "successfully" in data["message"].lower()
|
||||
|
||||
# Verify user can login with new password
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
updated_user = result.scalar_one_or_none()
|
||||
from app.core.auth import verify_password
|
||||
|
||||
assert verify_password(new_password, updated_user.password_hash) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -184,10 +198,7 @@ class TestPasswordResetConfirm:
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -202,10 +213,7 @@ class TestPasswordResetConfirm:
|
||||
"""Test password reset confirmation with invalid token."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": "invalid_token_xyz",
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": "invalid_token_xyz", "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -222,19 +230,18 @@ class TestPasswordResetConfirm:
|
||||
|
||||
# Create valid token and tamper with it
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["email"] = "hacker@example.com"
|
||||
|
||||
# Re-encode tampered token
|
||||
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
|
||||
tampered = base64.urlsafe_b64encode(
|
||||
json.dumps(token_data).encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": tampered,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": tampered, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -247,10 +254,7 @@ class TestPasswordResetConfirm:
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -260,12 +264,16 @@ class TestPasswordResetConfirm:
|
||||
assert "not found" in error_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_confirm_inactive_user(self, client, async_test_user, async_test_db):
|
||||
async def test_password_reset_confirm_inactive_user(
|
||||
self, client, async_test_user, async_test_db
|
||||
):
|
||||
"""Test password reset confirmation for inactive user."""
|
||||
# Deactivate user
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user_in_session = result.scalar_one_or_none()
|
||||
user_in_session.is_active = False
|
||||
await session.commit()
|
||||
@@ -274,10 +282,7 @@ class TestPasswordResetConfirm:
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -301,10 +306,7 @@ class TestPasswordResetConfirm:
|
||||
for weak_password in weak_passwords:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": weak_password
|
||||
}
|
||||
json={"token": token, "new_password": weak_password},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -315,15 +317,14 @@ class TestPasswordResetConfirm:
|
||||
# Missing token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={"new_password": "NewSecure123!"}
|
||||
json={"new_password": "NewSecure123!"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
# Missing password
|
||||
token = create_password_reset_token("test@example.com")
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={"token": token}
|
||||
"/api/v1/auth/password-reset/confirm", json={"token": token}
|
||||
)
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@@ -333,15 +334,12 @@ class TestPasswordResetConfirm:
|
||||
token = create_password_reset_token(async_test_user.email)
|
||||
|
||||
# Mock the database commit to raise an exception
|
||||
with patch('app.api.routes.auth.user_crud.get_by_email') as mock_get:
|
||||
with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get:
|
||||
mock_get.side_effect = Exception("Database error")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": token,
|
||||
"new_password": "NewSecure123!"
|
||||
}
|
||||
json={"token": token, "new_password": "NewSecure123!"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -351,18 +349,22 @@ class TestPasswordResetConfirm:
|
||||
assert "error" in error_msg or "resetting" in error_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_full_flow(self, client, async_test_user, async_test_db):
|
||||
async def test_password_reset_full_flow(
|
||||
self, client, async_test_user, async_test_db
|
||||
):
|
||||
"""Test complete password reset flow."""
|
||||
original_password = async_test_user.password_hash
|
||||
new_password = "BrandNew123!"
|
||||
|
||||
# Step 1: Request password reset
|
||||
with patch('app.api.routes.auth.email_service.send_password_reset_email') as mock_send:
|
||||
with patch(
|
||||
"app.api.routes.auth.email_service.send_password_reset_email"
|
||||
) as mock_send:
|
||||
mock_send.return_value = True
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/request",
|
||||
json={"email": async_test_user.email}
|
||||
json={"email": async_test_user.email},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -374,29 +376,24 @@ class TestPasswordResetConfirm:
|
||||
# Step 2: Confirm password reset
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
json={
|
||||
"token": reset_token,
|
||||
"new_password": new_password
|
||||
}
|
||||
json={"token": reset_token, "new_password": new_password},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Step 3: Verify old password doesn't work
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
updated_user = result.scalar_one_or_none()
|
||||
from app.core.auth import verify_password
|
||||
assert updated_user.password_hash != original_password
|
||||
|
||||
# Step 4: Verify new password works
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": new_password
|
||||
}
|
||||
json={"email": async_test_user.email, "password": new_password},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@@ -8,11 +8,10 @@ Critical security tests covering:
|
||||
|
||||
These tests prevent real-world attack scenarios.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import create_refresh_token
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user import User
|
||||
|
||||
@@ -30,10 +29,7 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_rejected_after_logout(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User
|
||||
):
|
||||
"""
|
||||
Test that refresh tokens are rejected after session is deactivated.
|
||||
@@ -45,10 +41,10 @@ class TestRevokedSessionSecurity:
|
||||
4. Attacker tries to use stolen refresh token
|
||||
5. System MUST reject it (session revoked)
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Step 1: Create a session and refresh token for the user
|
||||
async with SessionLocal() as session:
|
||||
async with SessionLocal():
|
||||
# Login to get tokens
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
@@ -64,8 +60,7 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
# Step 2: Verify refresh token works before logout
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
assert response.status_code == 200, "Refresh should work before logout"
|
||||
|
||||
@@ -73,14 +68,13 @@ class TestRevokedSessionSecurity:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"refresh_token": refresh_token}
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
assert response.status_code == 200, "Logout should succeed"
|
||||
|
||||
# Step 4: Attacker tries to use stolen refresh token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
# Step 5: System MUST reject (covers lines 261-262)
|
||||
@@ -93,10 +87,7 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_rejected_for_deleted_session(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User
|
||||
):
|
||||
"""
|
||||
Test that tokens for deleted sessions are rejected.
|
||||
@@ -104,7 +95,7 @@ class TestRevokedSessionSecurity:
|
||||
Attack Scenario:
|
||||
Admin deletes a session from database, but attacker has the token.
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Step 1: Login to create a session
|
||||
response = await client.post(
|
||||
@@ -120,6 +111,7 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
# Step 2: Manually delete the session from database (simulating admin action)
|
||||
from app.core.auth import decode_token
|
||||
|
||||
token_data = decode_token(refresh_token, verify_type="refresh")
|
||||
jti = token_data.jti
|
||||
|
||||
@@ -132,15 +124,17 @@ class TestRevokedSessionSecurity:
|
||||
|
||||
# Step 3: Try to use the refresh token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token}
|
||||
"/api/v1/auth/refresh", json={"refresh_token": refresh_token}
|
||||
)
|
||||
|
||||
# Should reject (session doesn't exist)
|
||||
assert response.status_code == 401
|
||||
data = response.json()
|
||||
if "errors" in data:
|
||||
assert "revoked" in data["errors"][0]["message"].lower() or "session" in data["errors"][0]["message"].lower()
|
||||
assert (
|
||||
"revoked" in data["errors"][0]["message"].lower()
|
||||
or "session" in data["errors"][0]["message"].lower()
|
||||
)
|
||||
else:
|
||||
assert "revoked" in data.get("detail", "").lower()
|
||||
|
||||
@@ -162,7 +156,7 @@ class TestSessionHijackingSecurity:
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User,
|
||||
async_test_superuser: User
|
||||
async_test_superuser: User,
|
||||
):
|
||||
"""
|
||||
Test that users cannot logout other users' sessions.
|
||||
@@ -173,7 +167,7 @@ class TestSessionHijackingSecurity:
|
||||
3. User A tries to logout User B's session
|
||||
4. System MUST reject (cross-user attack)
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, _SessionLocal = async_test_db
|
||||
|
||||
# Step 1: User A logs in
|
||||
response = await client.post(
|
||||
@@ -202,8 +196,10 @@ class TestSessionHijackingSecurity:
|
||||
# Step 3: User A tries to logout User B's session using User B's refresh token
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {user_a_access}"}, # User A's access token
|
||||
json={"refresh_token": user_b_refresh} # But User B's refresh token
|
||||
headers={
|
||||
"Authorization": f"Bearer {user_a_access}"
|
||||
}, # User A's access token
|
||||
json={"refresh_token": user_b_refresh}, # But User B's refresh token
|
||||
)
|
||||
|
||||
# Step 4: System MUST reject (covers lines 509-513)
|
||||
@@ -217,9 +213,7 @@ class TestSessionHijackingSecurity:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_users_can_logout_their_own_sessions(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_user: User
|
||||
self, client: AsyncClient, async_test_user: User
|
||||
):
|
||||
"""
|
||||
Sanity check: Users CAN logout their own sessions.
|
||||
@@ -241,6 +235,8 @@ class TestSessionHijackingSecurity:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
json={"refresh_token": tokens["refresh_token"]}
|
||||
json={"refresh_token": tokens["refresh_token"]},
|
||||
)
|
||||
assert response.status_code == 200, (
|
||||
"Users should be able to logout their own sessions"
|
||||
)
|
||||
assert response.status_code == 200, "Users should be able to logout their own sessions"
|
||||
|
||||
@@ -5,16 +5,18 @@ Tests for organization routes (user endpoints).
|
||||
These test the routes in app/api/routes/organizations.py which allow
|
||||
users to view and manage organizations they belong to.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -22,10 +24,7 @@ async def user_token(client, async_test_user):
|
||||
"""Get access token for regular user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -34,7 +33,7 @@ async def user_token(client, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def second_user(async_test_db):
|
||||
"""Create a second test user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = User(
|
||||
id=uuid4(),
|
||||
@@ -56,12 +55,12 @@ async def second_user(async_test_db):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_user_member(async_test_db, async_test_user):
|
||||
"""Create a test organization with async_test_user as a member."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Member Org",
|
||||
slug="member-org",
|
||||
description="Test organization where user is a member"
|
||||
description="Test organization where user is a member",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -72,7 +71,7 @@ async def test_org_with_user_member(async_test_db, async_test_user):
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -83,12 +82,12 @@ async def test_org_with_user_member(async_test_db, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_user_admin(async_test_db, async_test_user):
|
||||
"""Create a test organization with async_test_user as an admin."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Admin Org",
|
||||
slug="admin-org",
|
||||
description="Test organization where user is an admin"
|
||||
description="Test organization where user is an admin",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -99,7 +98,7 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -110,12 +109,12 @@ async def test_org_with_user_admin(async_test_db, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_user_owner(async_test_db, async_test_user):
|
||||
"""Create a test organization with async_test_user as owner."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Owner Org",
|
||||
slug="owner-org",
|
||||
description="Test organization where user is owner"
|
||||
description="Test organization where user is owner",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -126,7 +125,7 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.OWNER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -136,21 +135,18 @@ async def test_org_with_user_owner(async_test_db, async_test_user):
|
||||
|
||||
# ===== GET /api/v1/organizations/me =====
|
||||
|
||||
|
||||
class TestGetMyOrganizations:
|
||||
"""Tests for GET /api/v1/organizations/me endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_organizations_success(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member,
|
||||
test_org_with_user_admin
|
||||
self, client, user_token, test_org_with_user_member, test_org_with_user_admin
|
||||
):
|
||||
"""Test successfully getting user's organizations (covers lines 54-79)."""
|
||||
response = await client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -167,21 +163,15 @@ class TestGetMyOrganizations:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_organizations_filter_active(
|
||||
self,
|
||||
client,
|
||||
async_test_db,
|
||||
async_test_user,
|
||||
user_token
|
||||
self, client, async_test_db, async_test_user, user_token
|
||||
):
|
||||
"""Test filtering organizations by active status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active org
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_org = Organization(
|
||||
name="Active Org",
|
||||
slug="active-org-filter",
|
||||
is_active=True
|
||||
name="Active Org", slug="active-org-filter", is_active=True
|
||||
)
|
||||
session.add(active_org)
|
||||
await session.commit()
|
||||
@@ -192,14 +182,14 @@ class TestGetMyOrganizations:
|
||||
user_id=async_test_user.id,
|
||||
organization_id=active_org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/organizations/me?is_active=true",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -209,7 +199,7 @@ class TestGetMyOrganizations:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_organizations_empty(self, client, async_test_db):
|
||||
"""Test getting organizations when user has none."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user with no org memberships
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -219,7 +209,7 @@ class TestGetMyOrganizations:
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="No",
|
||||
last_name="Org",
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
@@ -227,13 +217,12 @@ class TestGetMyOrganizations:
|
||||
# Login to get token
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "noorg@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "noorg@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
"/api/v1/organizations/me", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -243,20 +232,18 @@ class TestGetMyOrganizations:
|
||||
|
||||
# ===== GET /api/v1/organizations/{organization_id} =====
|
||||
|
||||
|
||||
class TestGetOrganization:
|
||||
"""Tests for GET /api/v1/organizations/{organization_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_success(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test successfully getting organization details (covers lines 103-122)."""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -272,7 +259,7 @@ class TestGetOrganization:
|
||||
fake_org_id = uuid4()
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{fake_org_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Permission dependency checks membership before endpoint logic
|
||||
@@ -283,20 +270,14 @@ class TestGetOrganization:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_not_member(
|
||||
self,
|
||||
client,
|
||||
async_test_db,
|
||||
async_test_user
|
||||
self, client, async_test_db, async_test_user
|
||||
):
|
||||
"""Test getting organization where user is not a member fails."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create org without adding user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Not Member Org",
|
||||
slug="not-member-org"
|
||||
)
|
||||
org = Organization(name="Not Member Org", slug="not-member-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -305,13 +286,13 @@ class TestGetOrganization:
|
||||
# Login as user
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
# Should fail permission check
|
||||
@@ -320,6 +301,7 @@ class TestGetOrganization:
|
||||
|
||||
# ===== GET /api/v1/organizations/{organization_id}/members =====
|
||||
|
||||
|
||||
class TestGetOrganizationMembers:
|
||||
"""Tests for GET /api/v1/organizations/{organization_id}/members endpoint."""
|
||||
|
||||
@@ -331,10 +313,10 @@ class TestGetOrganizationMembers:
|
||||
async_test_user,
|
||||
second_user,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
test_org_with_user_member,
|
||||
):
|
||||
"""Test successfully getting organization members (covers lines 150-168)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Add second user to org
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -342,14 +324,14 @@ class TestGetOrganizationMembers:
|
||||
user_id=second_user.id,
|
||||
organization_id=test_org_with_user_member.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}/members",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -360,15 +342,12 @@ class TestGetOrganizationMembers:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_members_with_pagination(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test pagination parameters."""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}/members?page=1&limit=10",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -385,10 +364,10 @@ class TestGetOrganizationMembers:
|
||||
async_test_user,
|
||||
second_user,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
test_org_with_user_member,
|
||||
):
|
||||
"""Test filtering members by active status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Add second user as inactive member
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -396,7 +375,7 @@ class TestGetOrganizationMembers:
|
||||
user_id=second_user.id,
|
||||
organization_id=test_org_with_user_member.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=False
|
||||
is_active=False,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -404,7 +383,7 @@ class TestGetOrganizationMembers:
|
||||
# Filter for active only
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}/members?is_active=true",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -416,31 +395,26 @@ class TestGetOrganizationMembers:
|
||||
|
||||
# ===== PUT /api/v1/organizations/{organization_id} =====
|
||||
|
||||
|
||||
class TestUpdateOrganization:
|
||||
"""Tests for PUT /api/v1/organizations/{organization_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_as_admin_success(
|
||||
self,
|
||||
client,
|
||||
async_test_user,
|
||||
test_org_with_user_admin
|
||||
self, client, async_test_user, test_org_with_user_admin
|
||||
):
|
||||
"""Test successfully updating organization as admin (covers lines 193-215)."""
|
||||
# Login as admin user
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
admin_token = login_response.json()["access_token"]
|
||||
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_user_admin.id}",
|
||||
json={
|
||||
"name": "Updated Admin Org",
|
||||
"description": "Updated description"
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
json={"name": "Updated Admin Org", "description": "Updated description"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -450,23 +424,20 @@ class TestUpdateOrganization:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_as_owner_success(
|
||||
self,
|
||||
client,
|
||||
async_test_user,
|
||||
test_org_with_user_owner
|
||||
self, client, async_test_user, test_org_with_user_owner
|
||||
):
|
||||
"""Test successfully updating organization as owner."""
|
||||
# Login as owner user
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
owner_token = login_response.json()["access_token"]
|
||||
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_user_owner.id}",
|
||||
json={"name": "Updated Owner Org"},
|
||||
headers={"Authorization": f"Bearer {owner_token}"}
|
||||
headers={"Authorization": f"Bearer {owner_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -475,16 +446,13 @@ class TestUpdateOrganization:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_as_member_fails(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test updating organization as regular member fails."""
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}",
|
||||
json={"name": "Should Fail"},
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Should fail permission check (need admin or owner)
|
||||
@@ -492,15 +460,13 @@ class TestUpdateOrganization:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_not_found(
|
||||
self,
|
||||
client,
|
||||
test_org_with_user_admin
|
||||
self, client, test_org_with_user_admin
|
||||
):
|
||||
"""Test updating nonexistent organization returns 403 (permission check first)."""
|
||||
# Login as admin
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
admin_token = login_response.json()["access_token"]
|
||||
|
||||
@@ -508,7 +474,7 @@ class TestUpdateOrganization:
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{fake_org_id}",
|
||||
json={"name": "Updated"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
# Permission dependency checks admin role before endpoint logic
|
||||
@@ -520,6 +486,7 @@ class TestUpdateOrganization:
|
||||
|
||||
# ===== Authentication Tests =====
|
||||
|
||||
|
||||
class TestOrganizationAuthentication:
|
||||
"""Test authentication requirements for organization endpoints."""
|
||||
|
||||
@@ -548,14 +515,14 @@ class TestOrganizationAuthentication:
|
||||
"""Test unauthenticated access to update fails."""
|
||||
fake_id = uuid4()
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{fake_id}",
|
||||
json={"name": "Test"}
|
||||
f"/api/v1/organizations/{fake_id}", json={"name": "Test"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
# ===== Exception Handler Tests (Database Error Scenarios) =====
|
||||
|
||||
|
||||
class TestOrganizationExceptionHandlers:
|
||||
"""
|
||||
Test exception handlers in organization endpoints.
|
||||
@@ -566,86 +533,74 @@ class TestOrganizationExceptionHandlers:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_organizations_database_error(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test generic exception handler in get_my_organizations (covers lines 81-83)."""
|
||||
with patch(
|
||||
"app.crud.organization.organization.get_user_organizations_with_details",
|
||||
side_effect=Exception("Database connection lost")
|
||||
side_effect=Exception("Database connection lost"),
|
||||
):
|
||||
# The exception handler logs and re-raises, so we expect the exception
|
||||
# to propagate (which proves the handler executed)
|
||||
with pytest.raises(Exception, match="Database connection lost"):
|
||||
await client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_database_error(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test generic exception handler in get_organization (covers lines 124-128)."""
|
||||
with patch(
|
||||
"app.crud.organization.organization.get",
|
||||
side_effect=Exception("Database timeout")
|
||||
side_effect=Exception("Database timeout"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Database timeout"):
|
||||
await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_members_database_error(
|
||||
self,
|
||||
client,
|
||||
user_token,
|
||||
test_org_with_user_member
|
||||
self, client, user_token, test_org_with_user_member
|
||||
):
|
||||
"""Test generic exception handler in get_organization_members (covers lines 170-172)."""
|
||||
with patch(
|
||||
"app.crud.organization.organization.get_organization_members",
|
||||
side_effect=Exception("Connection pool exhausted")
|
||||
side_effect=Exception("Connection pool exhausted"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Connection pool exhausted"):
|
||||
await client.get(
|
||||
f"/api/v1/organizations/{test_org_with_user_member.id}/members",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_database_error(
|
||||
self,
|
||||
client,
|
||||
async_test_user,
|
||||
test_org_with_user_admin
|
||||
self, client, async_test_user, test_org_with_user_admin
|
||||
):
|
||||
"""Test generic exception handler in update_organization (covers lines 217-221)."""
|
||||
# Login as admin user
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
admin_token = login_response.json()["access_token"]
|
||||
|
||||
with patch(
|
||||
"app.crud.organization.organization.get",
|
||||
return_value=test_org_with_user_admin
|
||||
return_value=test_org_with_user_admin,
|
||||
):
|
||||
with patch(
|
||||
"app.crud.organization.organization.update",
|
||||
side_effect=Exception("Write lock timeout")
|
||||
side_effect=Exception("Write lock timeout"),
|
||||
):
|
||||
with pytest.raises(Exception, match="Write lock timeout"):
|
||||
await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_user_admin.id}",
|
||||
json={"name": "Should Fail"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
@@ -5,15 +5,17 @@ Tests for permission dependencies - CRITICAL SECURITY PATHS.
|
||||
These tests ensure superusers can bypass organization checks correctly,
|
||||
and that regular users are properly blocked.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -21,10 +23,7 @@ async def superuser_token(client, async_test_superuser):
|
||||
"""Get access token for superuser."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "superuser@example.com",
|
||||
"password": "SuperPassword123!"
|
||||
}
|
||||
json={"email": "superuser@example.com", "password": "SuperPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -35,10 +34,7 @@ async def regular_user_token(client, async_test_user):
|
||||
"""Get access token for regular user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -47,12 +43,12 @@ async def regular_user_token(client, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_no_members(async_test_db):
|
||||
"""Create a test organization with NO members."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="No Members Org",
|
||||
slug="no-members-org",
|
||||
description="Test org with no members"
|
||||
description="Test org with no members",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -63,12 +59,12 @@ async def test_org_no_members(async_test_db):
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_member(async_test_db, async_test_user):
|
||||
"""Create a test organization with async_test_user as member (not admin)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Member Only Org",
|
||||
slug="member-only-org",
|
||||
description="Test org where user is just a member"
|
||||
description="Test org where user is just a member",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
@@ -79,7 +75,7 @@ async def test_org_with_member(async_test_db, async_test_user):
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -89,6 +85,7 @@ async def test_org_with_member(async_test_db, async_test_user):
|
||||
|
||||
# ===== CRITICAL SECURITY TESTS: Superuser Bypass =====
|
||||
|
||||
|
||||
class TestSuperuserBypass:
|
||||
"""
|
||||
CRITICAL: Test that superusers can bypass organization checks.
|
||||
@@ -99,10 +96,7 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_access_org_not_member_of(
|
||||
self,
|
||||
client,
|
||||
superuser_token,
|
||||
test_org_no_members
|
||||
self, client, superuser_token, test_org_no_members
|
||||
):
|
||||
"""
|
||||
CRITICAL: Superuser should bypass membership check (covers line 175).
|
||||
@@ -111,7 +105,7 @@ class TestSuperuserBypass:
|
||||
"""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_no_members.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Superuser should succeed even though they're not a member
|
||||
@@ -121,15 +115,12 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_access_org_not_member_of(
|
||||
self,
|
||||
client,
|
||||
regular_user_token,
|
||||
test_org_no_members
|
||||
self, client, regular_user_token, test_org_no_members
|
||||
):
|
||||
"""Regular user should be blocked from org they're not a member of."""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_no_members.id}",
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"}
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"},
|
||||
)
|
||||
|
||||
# Regular user should fail permission check
|
||||
@@ -137,10 +128,7 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_update_org_not_admin_of(
|
||||
self,
|
||||
client,
|
||||
superuser_token,
|
||||
test_org_no_members
|
||||
self, client, superuser_token, test_org_no_members
|
||||
):
|
||||
"""
|
||||
CRITICAL: Superuser should bypass admin check (covers line 99).
|
||||
@@ -150,7 +138,7 @@ class TestSuperuserBypass:
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_no_members.id}",
|
||||
json={"name": "Updated by Superuser"},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Superuser should succeed in updating org
|
||||
@@ -160,16 +148,13 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_member_cannot_update_org(
|
||||
self,
|
||||
client,
|
||||
regular_user_token,
|
||||
test_org_with_member
|
||||
self, client, regular_user_token, test_org_with_member
|
||||
):
|
||||
"""Regular member (not admin) should NOT be able to update org."""
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_member.id}",
|
||||
json={"name": "Should Fail"},
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"}
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"},
|
||||
)
|
||||
|
||||
# Member should fail - need admin or owner role
|
||||
@@ -177,15 +162,12 @@ class TestSuperuserBypass:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_list_org_members_not_member_of(
|
||||
self,
|
||||
client,
|
||||
superuser_token,
|
||||
test_org_no_members
|
||||
self, client, superuser_token, test_org_no_members
|
||||
):
|
||||
"""CRITICAL: Superuser should bypass membership check to list members."""
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{test_org_no_members.id}/members",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Superuser should succeed
|
||||
@@ -197,13 +179,14 @@ class TestSuperuserBypass:
|
||||
|
||||
# ===== Edge Cases and Security Tests =====
|
||||
|
||||
|
||||
class TestPermissionEdgeCases:
|
||||
"""Test edge cases in permission system."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inactive_user_blocked(self, client, async_test_db):
|
||||
"""Test that inactive users are blocked."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -213,7 +196,7 @@ class TestPermissionEdgeCases:
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="Inactive",
|
||||
last_name="User",
|
||||
is_active=False # INACTIVE
|
||||
is_active=False, # INACTIVE
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
@@ -222,7 +205,7 @@ class TestPermissionEdgeCases:
|
||||
# But accessing protected endpoints should fail
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "inactive@example.com", "password": "TestPassword123!"}
|
||||
json={"email": "inactive@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
|
||||
# Login might fail for inactive users depending on auth implementation
|
||||
@@ -231,18 +214,18 @@ class TestPermissionEdgeCases:
|
||||
|
||||
# Try to access protected endpoint
|
||||
response = await client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
"/api/v1/users/me", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
# Should be blocked
|
||||
assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
|
||||
assert response.status_code in [
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonexistent_organization_returns_403_not_404(
|
||||
self,
|
||||
client,
|
||||
regular_user_token
|
||||
self, client, regular_user_token
|
||||
):
|
||||
"""
|
||||
Test that accessing nonexistent org returns 403, not 404.
|
||||
@@ -254,7 +237,7 @@ class TestPermissionEdgeCases:
|
||||
fake_org_id = uuid4()
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{fake_org_id}",
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"}
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"},
|
||||
)
|
||||
|
||||
# Should get 403 (not a member), not 404 (doesn't exist)
|
||||
@@ -264,18 +247,16 @@ class TestPermissionEdgeCases:
|
||||
|
||||
# ===== Admin Role Tests =====
|
||||
|
||||
|
||||
class TestAdminRolePermissions:
|
||||
"""Test admin role can perform admin actions."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org_with_admin(self, async_test_db, async_test_user):
|
||||
"""Create org where user is ADMIN."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Admin Org",
|
||||
slug="admin-org"
|
||||
)
|
||||
org = Organization(name="Admin Org", slug="admin-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -284,7 +265,7 @@ class TestAdminRolePermissions:
|
||||
user_id=async_test_user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.ADMIN,
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
session.add(membership)
|
||||
await session.commit()
|
||||
@@ -293,16 +274,13 @@ class TestAdminRolePermissions:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_update_org(
|
||||
self,
|
||||
client,
|
||||
regular_user_token,
|
||||
test_org_with_admin
|
||||
self, client, regular_user_token, test_org_with_admin
|
||||
):
|
||||
"""Admin should be able to update organization."""
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{test_org_with_admin.id}",
|
||||
json={"name": "Updated by Admin"},
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"}
|
||||
headers={"Authorization": f"Bearer {regular_user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@@ -7,13 +7,13 @@ Critical security tests covering:
|
||||
|
||||
These tests prevent unauthorized access and privilege escalation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.organization import Organization
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class TestInactiveUserBlocking:
|
||||
@@ -29,11 +29,7 @@ class TestInactiveUserBlocking:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inactive_user_cannot_access_protected_endpoints(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User,
|
||||
user_token: str
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
|
||||
):
|
||||
"""
|
||||
Test that inactive users are blocked from protected endpoints.
|
||||
@@ -44,12 +40,11 @@ class TestInactiveUserBlocking:
|
||||
3. User tries to access protected endpoint with valid token
|
||||
4. System MUST reject (account inactive)
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Step 1: Verify user can access endpoint while active
|
||||
response = await client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
assert response.status_code == 200, "Active user should have access"
|
||||
|
||||
@@ -61,8 +56,7 @@ class TestInactiveUserBlocking:
|
||||
|
||||
# Step 3: User tries to access endpoint with same token
|
||||
response = await client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
# Step 4: System MUST reject (covers lines 52-57)
|
||||
@@ -75,18 +69,14 @@ class TestInactiveUserBlocking:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inactive_user_blocked_from_organization_endpoints(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User,
|
||||
user_token: str
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
|
||||
):
|
||||
"""
|
||||
Test that inactive users can't access organization endpoints.
|
||||
|
||||
Ensures the inactive check applies to ALL protected endpoints.
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Deactivate user
|
||||
async with SessionLocal() as session:
|
||||
@@ -97,7 +87,7 @@ class TestInactiveUserBlocking:
|
||||
# Try to list organizations
|
||||
response = await client.get(
|
||||
"/api/v1/organizations/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Must be blocked
|
||||
@@ -122,7 +112,7 @@ class TestSuperuserPrivilegeEscalation:
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_superuser: User,
|
||||
superuser_token: str
|
||||
superuser_token: str,
|
||||
):
|
||||
"""
|
||||
Test that superusers automatically get OWNER role in organizations.
|
||||
@@ -131,14 +121,11 @@ class TestSuperuserPrivilegeEscalation:
|
||||
Superusers can manage any organization without being explicitly added.
|
||||
This is for platform administration.
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Step 1: Create an organization (owned by someone else)
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = Organization(name="Test Organization", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -148,7 +135,7 @@ class TestSuperuserPrivilegeEscalation:
|
||||
# (They're not a member, but should auto-get OWNER role)
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Step 3: Should have access (covers lines 154-157)
|
||||
@@ -161,21 +148,18 @@ class TestSuperuserPrivilegeEscalation:
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_superuser: User,
|
||||
superuser_token: str
|
||||
superuser_token: str,
|
||||
):
|
||||
"""
|
||||
Test that superusers have full management access to all organizations.
|
||||
|
||||
Ensures the OWNER role privilege escalation works end-to-end.
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = Organization(name="Test Organization", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -185,34 +169,29 @@ class TestSuperuserPrivilegeEscalation:
|
||||
response = await client.put(
|
||||
f"/api/v1/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"name": "Updated Name"}
|
||||
json={"name": "Updated Name"},
|
||||
)
|
||||
|
||||
# Should succeed (superuser has OWNER privileges)
|
||||
assert response.status_code in [200, 404], "Superuser should be able to manage any org"
|
||||
assert response.status_code in [200, 404], (
|
||||
"Superuser should be able to manage any org"
|
||||
)
|
||||
# Note: Might be 404 if org endpoints require membership, but the role check passes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_does_not_get_owner_role(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
async_test_db,
|
||||
async_test_user: User,
|
||||
user_token: str
|
||||
self, client: AsyncClient, async_test_db, async_test_user: User, user_token: str
|
||||
):
|
||||
"""
|
||||
Sanity check: Regular users don't get automatic OWNER role.
|
||||
|
||||
Ensures the superuser check is working correctly (line 154).
|
||||
"""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = Organization(name="Test Organization", slug="test-org")
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
@@ -221,8 +200,10 @@ class TestSuperuserPrivilegeEscalation:
|
||||
# Regular user tries to access it (not a member)
|
||||
response = await client.get(
|
||||
f"/api/v1/organizations/{org_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Should be denied (not a member, not a superuser)
|
||||
assert response.status_code in [403, 404], "Regular user shouldn't access non-member org"
|
||||
assert response.status_code in [403, 404], (
|
||||
"Regular user shouldn't access non-member org"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# tests/api/test_security_headers.py
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.main import app
|
||||
|
||||
@@ -11,8 +12,10 @@ def client():
|
||||
"""Create a FastAPI test client for the main app (module-scoped for speed)."""
|
||||
# Mock get_db to avoid database connection issues
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
|
||||
async def mock_session_generator():
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute = AsyncMock(return_value=None)
|
||||
mock_session.close = AsyncMock(return_value=None)
|
||||
@@ -77,8 +80,10 @@ class TestSecurityHeaders:
|
||||
"""Test that HSTS header is set in production (covers line 95)"""
|
||||
with patch("app.core.config.settings.ENVIRONMENT", "production"):
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
|
||||
async def mock_session_generator():
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute = AsyncMock(return_value=None)
|
||||
mock_session.close = AsyncMock(return_value=None)
|
||||
@@ -88,20 +93,26 @@ class TestSecurityHeaders:
|
||||
|
||||
# Need to reimport app to pick up the new settings
|
||||
from importlib import reload
|
||||
|
||||
import app.main
|
||||
|
||||
reload(app.main)
|
||||
test_client = TestClient(app.main.app)
|
||||
|
||||
response = test_client.get("/health")
|
||||
assert "Strict-Transport-Security" in response.headers
|
||||
assert "max-age=31536000" in response.headers["Strict-Transport-Security"]
|
||||
assert (
|
||||
"max-age=31536000" in response.headers["Strict-Transport-Security"]
|
||||
)
|
||||
|
||||
def test_csp_strict_mode(self):
|
||||
"""Test CSP strict mode (covers line 121)"""
|
||||
with patch("app.core.config.settings.CSP_MODE", "strict"):
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
|
||||
async def mock_session_generator():
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute = AsyncMock(return_value=None)
|
||||
mock_session.close = AsyncMock(return_value=None)
|
||||
@@ -110,7 +121,9 @@ class TestSecurityHeaders:
|
||||
mock_get_db.side_effect = lambda: mock_session_generator()
|
||||
|
||||
from importlib import reload
|
||||
|
||||
import app.main
|
||||
|
||||
reload(app.main)
|
||||
test_client = TestClient(app.main.app)
|
||||
|
||||
@@ -136,8 +149,10 @@ class TestRootEndpoint:
|
||||
def test_root_endpoint(self):
|
||||
"""Test root endpoint returns HTML (covers line 174)"""
|
||||
with patch("app.core.database.get_db") as mock_get_db:
|
||||
|
||||
async def mock_session_generator():
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute = AsyncMock(return_value=None)
|
||||
mock_session.close = AsyncMock(return_value=None)
|
||||
|
||||
@@ -2,23 +2,23 @@
|
||||
"""
|
||||
Comprehensive tests for session management API endpoints.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi import status
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
# Disable rate limiting for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limit():
|
||||
"""Disable rate limiting for all tests in this module."""
|
||||
with patch('app.api.routes.sessions.limiter.enabled', False):
|
||||
with patch("app.api.routes.sessions.limiter.enabled", False):
|
||||
yield
|
||||
|
||||
|
||||
@@ -27,10 +27,7 @@ async def user_token(client, async_test_user):
|
||||
"""Create and return an access token for async_test_user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -39,7 +36,7 @@ async def user_token(client, async_test_user):
|
||||
@pytest_asyncio.fixture
|
||||
async def async_test_user2(async_test_db):
|
||||
"""Create a second test user."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.user import user as user_crud
|
||||
@@ -49,7 +46,7 @@ async def async_test_user2(async_test_db):
|
||||
email="testuser2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User2"
|
||||
last_name="User2",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
await session.commit()
|
||||
@@ -61,9 +58,11 @@ class TestListMySessions:
|
||||
"""Tests for GET /api/v1/sessions/me endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_my_sessions_success(self, client, async_test_user, async_test_db, user_token):
|
||||
async def test_list_my_sessions_success(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test successfully listing user's active sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create some sessions for the user
|
||||
async with SessionLocal() as session:
|
||||
@@ -75,8 +74,8 @@ class TestListMySessions:
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0 (iPhone)",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
# Active session 2
|
||||
s2 = UserSession(
|
||||
@@ -86,8 +85,8 @@ class TestListMySessions:
|
||||
ip_address="192.168.1.101",
|
||||
user_agent="Mozilla/5.0 (Macintosh)",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
# Inactive session (should not appear)
|
||||
s3 = UserSession(
|
||||
@@ -97,16 +96,15 @@ class TestListMySessions:
|
||||
ip_address="192.168.1.102",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
session.add_all([s1, s2, s3])
|
||||
await session.commit()
|
||||
|
||||
# Make request
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -128,11 +126,12 @@ class TestListMySessions:
|
||||
assert data["sessions"][0]["is_current"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_my_sessions_with_login_session(self, client, async_test_user, user_token):
|
||||
async def test_list_my_sessions_with_login_session(
|
||||
self, client, async_test_user, user_token
|
||||
):
|
||||
"""Test listing sessions shows the login session."""
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -155,9 +154,11 @@ class TestRevokeSession:
|
||||
"""Tests for DELETE /api/v1/sessions/{session_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session_success(self, client, async_test_user, async_test_db, user_token):
|
||||
async def test_revoke_session_success(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test successfully revoking a session."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session to revoke
|
||||
async with SessionLocal() as session:
|
||||
@@ -168,8 +169,8 @@ class TestRevokeSession:
|
||||
ip_address="192.168.1.103",
|
||||
user_agent="Mozilla/5.0 (iPad)",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -179,7 +180,7 @@ class TestRevokeSession:
|
||||
# Revoke the session
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -191,6 +192,7 @@ class TestRevokeSession:
|
||||
# Verify session is deactivated
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.session import session as session_crud
|
||||
|
||||
revoked_session = await session_crud.get(session, id=str(session_id))
|
||||
assert revoked_session.is_active is False
|
||||
|
||||
@@ -200,7 +202,7 @@ class TestRevokeSession:
|
||||
fake_id = uuid4()
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{fake_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -222,7 +224,7 @@ class TestRevokeSession:
|
||||
self, client, async_test_user, async_test_user2, async_test_db, user_token
|
||||
):
|
||||
"""Test that users cannot revoke other users' sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session for user2
|
||||
async with SessionLocal() as session:
|
||||
@@ -233,8 +235,8 @@ class TestRevokeSession:
|
||||
ip_address="192.168.1.200",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(other_user_session)
|
||||
await session.commit()
|
||||
@@ -244,7 +246,7 @@ class TestRevokeSession:
|
||||
# Try to revoke it as user1
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
@@ -263,7 +265,7 @@ class TestCleanupExpiredSessions:
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test successfully cleaning up expired sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create expired and active sessions using CRUD to avoid greenlet issues
|
||||
from app.crud.session import session as session_crud
|
||||
@@ -277,8 +279,8 @@ class TestCleanupExpiredSessions:
|
||||
device_name="Expired 1",
|
||||
ip_address="192.168.1.201",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
e1 = await session_crud.create_session(db, obj_in=e1_data)
|
||||
e1.is_active = False
|
||||
@@ -291,8 +293,8 @@ class TestCleanupExpiredSessions:
|
||||
device_name="Expired 2",
|
||||
ip_address="192.168.1.202",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
)
|
||||
e2 = await session_crud.create_session(db, obj_in=e2_data)
|
||||
e2.is_active = False
|
||||
@@ -305,8 +307,8 @@ class TestCleanupExpiredSessions:
|
||||
device_name="Active",
|
||||
ip_address="192.168.1.203",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
await session_crud.create_session(db, obj_in=a1_data)
|
||||
await db.commit()
|
||||
@@ -314,7 +316,7 @@ class TestCleanupExpiredSessions:
|
||||
# Cleanup expired sessions
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -329,7 +331,7 @@ class TestCleanupExpiredSessions:
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test cleanup when no sessions are expired."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create only active sessions using CRUD
|
||||
from app.crud.session import session as session_crud
|
||||
@@ -342,15 +344,15 @@ class TestCleanupExpiredSessions:
|
||||
device_name="Active Device",
|
||||
ip_address="192.168.1.210",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
await session_crud.create_session(db, obj_in=a1_data)
|
||||
await db.commit()
|
||||
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -369,13 +371,16 @@ class TestCleanupExpiredSessions:
|
||||
|
||||
# Additional tests for better coverage
|
||||
|
||||
|
||||
class TestSessionsAdditionalCases:
|
||||
"""Additional tests to improve sessions endpoint coverage."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_pagination(self, client, async_test_user, async_test_db, user_token):
|
||||
async def test_list_sessions_pagination(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test listing sessions with pagination."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create multiple sessions
|
||||
async with SessionLocal() as session:
|
||||
@@ -389,15 +394,15 @@ class TestSessionsAdditionalCases:
|
||||
device_name=f"Device {i}",
|
||||
ip_address=f"192.168.1.{i}",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
await session_crud.create_session(session, obj_in=session_data)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me?page=1&limit=3",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -410,16 +415,21 @@ class TestSessionsAdditionalCases:
|
||||
"""Test revoking session with invalid UUID."""
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/not-a-uuid",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
# Should return 422 for invalid UUID format
|
||||
assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_404_NOT_FOUND]
|
||||
assert response.status_code in [
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status.HTTP_404_NOT_FOUND,
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_with_mixed_states(self, client, async_test_user, async_test_db, user_token):
|
||||
async def test_cleanup_expired_sessions_with_mixed_states(
|
||||
self, client, async_test_user, async_test_db, user_token
|
||||
):
|
||||
"""Test cleanup with mix of active/inactive and expired/not-expired sessions."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
@@ -432,8 +442,8 @@ class TestSessionsAdditionalCases:
|
||||
device_name="Expired Inactive",
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
e1 = await session_crud.create_session(db, obj_in=e1_data)
|
||||
e1.is_active = False
|
||||
@@ -446,8 +456,8 @@ class TestSessionsAdditionalCases:
|
||||
device_name="Expired Active",
|
||||
ip_address="192.168.1.101",
|
||||
user_agent="Mozilla/5.0",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
)
|
||||
await session_crud.create_session(db, obj_in=e2_data)
|
||||
|
||||
@@ -455,7 +465,7 @@ class TestSessionsAdditionalCases:
|
||||
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -476,10 +486,12 @@ class TestSessionExceptionHandlers:
|
||||
from unittest.mock import patch
|
||||
|
||||
# Patch decode_token to raise an exception
|
||||
with patch('app.api.routes.sessions.decode_token', side_effect=Exception("Token decode error")):
|
||||
with patch(
|
||||
"app.api.routes.sessions.decode_token",
|
||||
side_effect=Exception("Token decode error"),
|
||||
):
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
# Should still succeed (exception is caught and ignored in try/except at line 77)
|
||||
@@ -489,12 +501,16 @@ class TestSessionExceptionHandlers:
|
||||
async def test_list_sessions_database_error(self, client, user_token):
|
||||
"""Test list_sessions handles database errors (covers lines 104-106)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.crud import session as session_module
|
||||
|
||||
with patch.object(session_module.session, 'get_user_sessions', side_effect=Exception("Database error")):
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
"get_user_sessions",
|
||||
side_effect=Exception("Database error"),
|
||||
):
|
||||
response = await client.get(
|
||||
"/api/v1/sessions/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/sessions/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -503,18 +519,21 @@ class TestSessionExceptionHandlers:
|
||||
assert data["errors"][0]["message"] == "Failed to retrieve sessions"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session_database_error(self, client, user_token, async_test_db, async_test_user):
|
||||
async def test_revoke_session_database_error(
|
||||
self, client, user_token, async_test_db, async_test_user
|
||||
):
|
||||
"""Test revoke_session handles database errors (covers lines 181-183)."""
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud import session as session_module
|
||||
|
||||
# First create a session to revoke
|
||||
from app.crud.session import session as session_crud
|
||||
from app.schemas.sessions import SessionCreate
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as db:
|
||||
session_in = SessionCreate(
|
||||
@@ -523,17 +542,21 @@ class TestSessionExceptionHandlers:
|
||||
device_name="Test Device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.now(UTC) + timedelta(days=60),
|
||||
)
|
||||
user_session = await session_crud.create_session(db, obj_in=session_in)
|
||||
session_id = user_session.id
|
||||
|
||||
# Mock the deactivate method to raise an exception
|
||||
with patch.object(session_module.session, 'deactivate', side_effect=Exception("Database connection lost")):
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
"deactivate",
|
||||
side_effect=Exception("Database connection lost"),
|
||||
):
|
||||
response = await client.delete(
|
||||
f"/api/v1/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
@@ -544,12 +567,17 @@ class TestSessionExceptionHandlers:
|
||||
async def test_cleanup_expired_sessions_database_error(self, client, user_token):
|
||||
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.crud import session as session_module
|
||||
|
||||
with patch.object(session_module.session, 'cleanup_expired_for_user', side_effect=Exception("Cleanup failed")):
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
"cleanup_expired_for_user",
|
||||
side_effect=Exception("Cleanup failed"),
|
||||
):
|
||||
response = await client.delete(
|
||||
"/api/v1/sessions/me/expired",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@@ -3,32 +3,29 @@
|
||||
Comprehensive tests for user management endpoints.
|
||||
These tests focus on finding potential bugs, not just coverage.
|
||||
"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch
|
||||
from fastapi import status
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserUpdate
|
||||
|
||||
|
||||
# Disable rate limiting for tests
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_rate_limit():
|
||||
"""Disable rate limiting for all tests in this module."""
|
||||
with patch('app.api.routes.users.limiter.enabled', False):
|
||||
with patch('app.api.routes.auth.limiter.enabled', False):
|
||||
with patch("app.api.routes.users.limiter.enabled", False):
|
||||
with patch("app.api.routes.auth.limiter.enabled", False):
|
||||
yield
|
||||
|
||||
|
||||
async def get_auth_headers(client, email, password):
|
||||
"""Helper to get authentication headers."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": email, "password": password}
|
||||
"/api/v1/auth/login", json={"email": email, "password": password}
|
||||
)
|
||||
token = response.json()["access_token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
@@ -40,7 +37,9 @@ class TestListUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_superuser(self, client, async_test_superuser):
|
||||
"""Test listing users as superuser."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users", headers=headers)
|
||||
|
||||
@@ -53,16 +52,20 @@ class TestListUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_regular_user(self, client, async_test_user):
|
||||
"""Test that regular users cannot list users."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users", headers=headers)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_pagination(self, client, async_test_superuser, async_test_db):
|
||||
async def test_list_users_pagination(
|
||||
self, client, async_test_superuser, async_test_db
|
||||
):
|
||||
"""Test pagination works correctly."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -72,12 +75,14 @@ class TestListUsers:
|
||||
password_hash="hash",
|
||||
first_name=f"PagUser{i}",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
# Get first page
|
||||
response = await client.get("/api/v1/users?page=1&limit=5", headers=headers)
|
||||
@@ -88,9 +93,11 @@ class TestListUsers:
|
||||
assert data["pagination"]["total"] >= 15
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_filter_active(self, client, async_test_superuser, async_test_db):
|
||||
async def test_list_users_filter_active(
|
||||
self, client, async_test_superuser, async_test_db
|
||||
):
|
||||
"""Test filtering by active status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active and inactive users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -99,19 +106,21 @@ class TestListUsers:
|
||||
password_hash="hash",
|
||||
first_name="Active",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
inactive_user = User(
|
||||
email="inactivefilter@example.com",
|
||||
password_hash="hash",
|
||||
first_name="Inactive",
|
||||
is_active=False,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
session.add_all([active_user, inactive_user])
|
||||
await session.commit()
|
||||
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
# Filter for active users
|
||||
response = await client.get("/api/v1/users?is_active=true", headers=headers)
|
||||
@@ -130,9 +139,13 @@ class TestListUsers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_sort_by_email(self, client, async_test_superuser):
|
||||
"""Test sorting users by email."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users?sort_by=email&sort_order=asc", headers=headers)
|
||||
response = await client.get(
|
||||
"/api/v1/users?sort_by=email&sort_order=asc", headers=headers
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
emails = [u["email"] for u in data["data"]]
|
||||
@@ -154,7 +167,9 @@ class TestGetCurrentUserProfile:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_own_profile(self, client, async_test_user):
|
||||
"""Test getting own profile."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users/me", headers=headers)
|
||||
|
||||
@@ -176,12 +191,14 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_own_profile(self, client, async_test_user):
|
||||
"""Test updating own profile."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers=headers,
|
||||
json={"first_name": "Updated", "last_name": "Name"}
|
||||
json={"first_name": "Updated", "last_name": "Name"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -192,12 +209,12 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_phone_number(self, client, async_test_user, test_db):
|
||||
"""Test updating phone number with validation."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers=headers,
|
||||
json={"phone_number": "+19876543210"}
|
||||
"/api/v1/users/me", headers=headers, json={"phone_number": "+19876543210"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -207,12 +224,12 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_invalid_phone(self, client, async_test_user):
|
||||
"""Test that invalid phone numbers are rejected."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers=headers,
|
||||
json={"phone_number": "invalid"}
|
||||
"/api/v1/users/me", headers=headers, json={"phone_number": "invalid"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -220,14 +237,16 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_elevate_to_superuser(self, client, async_test_user):
|
||||
"""Test that users cannot make themselves superuser."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
# Note: is_superuser is now in UserUpdate schema with explicit validation
|
||||
# This tests that Pydantic rejects the attempt at the schema level
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers=headers,
|
||||
json={"first_name": "Test", "is_superuser": True}
|
||||
json={"first_name": "Test", "is_superuser": True},
|
||||
)
|
||||
|
||||
# Pydantic validation should reject this at the schema level
|
||||
@@ -242,10 +261,7 @@ class TestUpdateCurrentUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_no_auth(self, client):
|
||||
"""Test that unauthenticated requests are rejected."""
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
json={"first_name": "Hacker"}
|
||||
)
|
||||
response = await client.patch("/api/v1/users/me", json={"first_name": "Hacker"})
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
# Note: Removed test_update_profile_unexpected_error - see comment above
|
||||
@@ -257,16 +273,22 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_own_profile_by_id(self, client, async_test_user):
|
||||
"""Test getting own profile by ID."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
|
||||
response = await client.get(
|
||||
f"/api/v1/users/{async_test_user.id}", headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["email"] == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_other_user_as_regular_user(self, client, async_test_user, test_db):
|
||||
async def test_get_other_user_as_regular_user(
|
||||
self, client, async_test_user, test_db
|
||||
):
|
||||
"""Test that regular users cannot view other profiles."""
|
||||
# Create another user
|
||||
other_user = User(
|
||||
@@ -274,24 +296,32 @@ class TestGetUserById:
|
||||
password_hash="hash",
|
||||
first_name="Other",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
test_db.add(other_user)
|
||||
test_db.commit()
|
||||
test_db.refresh(other_user)
|
||||
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get(f"/api/v1/users/{other_user.id}", headers=headers)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_other_user_as_superuser(self, client, async_test_superuser, async_test_user):
|
||||
async def test_get_other_user_as_superuser(
|
||||
self, client, async_test_superuser, async_test_user
|
||||
):
|
||||
"""Test that superusers can view other profiles."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get(f"/api/v1/users/{async_test_user.id}", headers=headers)
|
||||
response = await client.get(
|
||||
f"/api/v1/users/{async_test_user.id}", headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
@@ -300,7 +330,9 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_user(self, client, async_test_superuser):
|
||||
"""Test getting non-existent user."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
response = await client.get(f"/api/v1/users/{fake_id}", headers=headers)
|
||||
@@ -310,7 +342,9 @@ class TestGetUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_invalid_uuid(self, client, async_test_superuser):
|
||||
"""Test getting user with invalid UUID format."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/users/not-a-uuid", headers=headers)
|
||||
|
||||
@@ -323,12 +357,14 @@ class TestUpdateUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_own_profile_by_id(self, client, async_test_user, test_db):
|
||||
"""Test updating own profile by ID."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "SelfUpdated"}
|
||||
json={"first_name": "SelfUpdated"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -336,7 +372,9 @@ class TestUpdateUserById:
|
||||
assert data["first_name"] == "SelfUpdated"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_other_user_as_regular_user(self, client, async_test_user, test_db):
|
||||
async def test_update_other_user_as_regular_user(
|
||||
self, client, async_test_user, test_db
|
||||
):
|
||||
"""Test that regular users cannot update other profiles."""
|
||||
# Create another user
|
||||
other_user = User(
|
||||
@@ -344,18 +382,20 @@ class TestUpdateUserById:
|
||||
password_hash="hash",
|
||||
first_name="Other",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
test_db.add(other_user)
|
||||
test_db.commit()
|
||||
test_db.refresh(other_user)
|
||||
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{other_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "Hacked"}
|
||||
json={"first_name": "Hacked"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
@@ -365,14 +405,18 @@ class TestUpdateUserById:
|
||||
assert other_user.first_name == "Other"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_other_user_as_superuser(self, client, async_test_superuser, async_test_user, test_db):
|
||||
async def test_update_other_user_as_superuser(
|
||||
self, client, async_test_superuser, async_test_user, test_db
|
||||
):
|
||||
"""Test that superusers can update other profiles."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "AdminUpdated"}
|
||||
json={"first_name": "AdminUpdated"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -380,16 +424,20 @@ class TestUpdateUserById:
|
||||
assert data["first_name"] == "AdminUpdated"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_modify_superuser_status(self, client, async_test_user):
|
||||
async def test_regular_user_cannot_modify_superuser_status(
|
||||
self, client, async_test_user
|
||||
):
|
||||
"""Test that regular users cannot change superuser status even if they try."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
# is_superuser not in UserUpdate schema, so it gets ignored by Pydantic
|
||||
# Just verify the user stays the same
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "Test"}
|
||||
json={"first_name": "Test"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -397,14 +445,18 @@ class TestUpdateUserById:
|
||||
assert data["is_superuser"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_update_users(self, client, async_test_superuser, async_test_user, test_db):
|
||||
async def test_superuser_can_update_users(
|
||||
self, client, async_test_superuser, async_test_user, test_db
|
||||
):
|
||||
"""Test that superusers can update other users."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers=headers,
|
||||
json={"first_name": "AdminChanged", "is_active": False}
|
||||
json={"first_name": "AdminChanged", "is_active": False},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -415,13 +467,13 @@ class TestUpdateUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_nonexistent_user(self, client, async_test_superuser):
|
||||
"""Test updating non-existent user."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{fake_id}",
|
||||
headers=headers,
|
||||
json={"first_name": "Ghost"}
|
||||
f"/api/v1/users/{fake_id}", headers=headers, json={"first_name": "Ghost"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -435,15 +487,17 @@ class TestChangePassword:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_success(self, client, async_test_user, test_db):
|
||||
"""Test successful password change."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers=headers,
|
||||
json={
|
||||
"current_password": "TestPassword123!",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
"new_password": "NewPassword123!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -453,25 +507,24 @@ class TestChangePassword:
|
||||
# Verify can login with new password
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "NewPassword123!"
|
||||
}
|
||||
json={"email": async_test_user.email, "password": "NewPassword123!"},
|
||||
)
|
||||
assert login_response.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_wrong_current(self, client, async_test_user):
|
||||
"""Test that wrong current password is rejected."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers=headers,
|
||||
json={
|
||||
"current_password": "WrongPassword123",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
"new_password": "NewPassword123!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
@@ -479,15 +532,14 @@ class TestChangePassword:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_weak_new_password(self, client, async_test_user):
|
||||
"""Test that weak new passwords are rejected."""
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me/password",
|
||||
headers=headers,
|
||||
json={
|
||||
"current_password": "TestPassword123!",
|
||||
"new_password": "weak"
|
||||
}
|
||||
json={"current_password": "TestPassword123!", "new_password": "weak"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -499,8 +551,8 @@ class TestChangePassword:
|
||||
"/api/v1/users/me/password",
|
||||
json={
|
||||
"current_password": "TestPassword123!",
|
||||
"new_password": "NewPassword123!"
|
||||
}
|
||||
"new_password": "NewPassword123!",
|
||||
},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -511,9 +563,11 @@ class TestDeleteUser:
|
||||
"""Tests for DELETE /users/{user_id} endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_as_superuser(self, client, async_test_superuser, async_test_db):
|
||||
async def test_delete_user_as_superuser(
|
||||
self, client, async_test_superuser, async_test_db
|
||||
):
|
||||
"""Test deleting a user as superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a user to delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -522,14 +576,16 @@ class TestDeleteUser:
|
||||
password_hash="hash",
|
||||
first_name="Delete",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
session.add(user_to_delete)
|
||||
await session.commit()
|
||||
await session.refresh(user_to_delete)
|
||||
user_id = user_to_delete.id
|
||||
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{user_id}", headers=headers)
|
||||
|
||||
@@ -540,6 +596,7 @@ class TestDeleteUser:
|
||||
# Verify user is soft-deleted (has deleted_at timestamp)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
deleted_user = result.scalar_one_or_none()
|
||||
assert deleted_user.deleted_at is not None
|
||||
@@ -547,9 +604,13 @@ class TestDeleteUser:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_delete_self(self, client, async_test_superuser):
|
||||
"""Test that users cannot delete their own account."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{async_test_superuser.id}", headers=headers)
|
||||
response = await client.delete(
|
||||
f"/api/v1/users/{async_test_superuser.id}", headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@@ -562,22 +623,28 @@ class TestDeleteUser:
|
||||
password_hash="hash",
|
||||
first_name="Protected",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
test_db.add(other_user)
|
||||
test_db.commit()
|
||||
test_db.refresh(other_user)
|
||||
|
||||
headers = await get_auth_headers(client, async_test_user.email, "TestPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_user.email, "TestPassword123!"
|
||||
)
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{other_user.id}", headers=headers)
|
||||
response = await client.delete(
|
||||
f"/api/v1/users/{other_user.id}", headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_user(self, client, async_test_superuser):
|
||||
"""Test deleting non-existent user."""
|
||||
headers = await get_auth_headers(client, async_test_superuser.email, "SuperPassword123!")
|
||||
headers = await get_auth_headers(
|
||||
client, async_test_superuser.email, "SuperPassword123!"
|
||||
)
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
response = await client.delete(f"/api/v1/users/{fake_id}", headers=headers)
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
"""
|
||||
Tests for user routes.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -13,10 +15,7 @@ async def superuser_token(client, async_test_superuser):
|
||||
"""Get access token for superuser."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "superuser@example.com",
|
||||
"password": "SuperPassword123!"
|
||||
}
|
||||
json={"email": "superuser@example.com", "password": "SuperPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -27,10 +26,7 @@ async def user_token(client, async_test_user):
|
||||
"""Get access token for regular user."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "testuser@example.com",
|
||||
"password": "TestPassword123!"
|
||||
}
|
||||
json={"email": "testuser@example.com", "password": "TestPassword123!"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["access_token"]
|
||||
@@ -43,8 +39,7 @@ class TestListUsers:
|
||||
async def test_list_users_success(self, client, superuser_token):
|
||||
"""Test listing users successfully (covers lines 87-100)."""
|
||||
response = await client.get(
|
||||
"/api/v1/users",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
"/api/v1/users", headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -58,7 +53,7 @@ class TestListUsers:
|
||||
"""Test listing users with is_superuser filter (covers line 74)."""
|
||||
response = await client.get(
|
||||
"/api/v1/users?is_superuser=true",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -73,8 +68,7 @@ class TestGetCurrentUser:
|
||||
async def test_get_current_user_success(self, client, async_test_user, user_token):
|
||||
"""Test getting current user profile."""
|
||||
response = await client.get(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
"/api/v1/users/me", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -92,7 +86,7 @@ class TestUpdateCurrentUser:
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"first_name": "UpdatedName"}
|
||||
json={"first_name": "UpdatedName"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -104,12 +98,14 @@ class TestUpdateCurrentUser:
|
||||
"""Test database error handling during update (covers lines 162-169)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("DB error")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update", side_effect=Exception("DB error")
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -118,7 +114,7 @@ class TestUpdateCurrentUser:
|
||||
response = await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"is_superuser": True}
|
||||
json={"is_superuser": True},
|
||||
)
|
||||
|
||||
# Pydantic validation should reject this at the schema level
|
||||
@@ -137,12 +133,15 @@ class TestUpdateCurrentUser:
|
||||
"""Test ValueError handling during update (covers lines 165-166)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid value")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update",
|
||||
side_effect=ValueError("Invalid value"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
await client.patch(
|
||||
"/api/v1/users/me",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
|
||||
@@ -154,7 +153,7 @@ class TestGetUser:
|
||||
"""Test getting user by ID."""
|
||||
response = await client.get(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -167,7 +166,7 @@ class TestGetUser:
|
||||
fake_id = uuid4()
|
||||
response = await client.get(
|
||||
f"/api/v1/users/{fake_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -183,30 +182,34 @@ class TestUpdateUserById:
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{fake_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(self, client, async_test_user, user_token):
|
||||
async def test_update_user_by_id_non_superuser_cannot_change_superuser_status(
|
||||
self, client, async_test_user, user_token
|
||||
):
|
||||
"""Test non-superuser cannot modify superuser status (Pydantic validation)."""
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"is_superuser": True}
|
||||
json={"is_superuser": True},
|
||||
)
|
||||
|
||||
# Pydantic validation should reject this at the schema level
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_by_id_success(self, client, async_test_user, superuser_token):
|
||||
async def test_update_user_by_id_success(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test updating user successfully (covers lines 276-278)."""
|
||||
response = await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"first_name": "SuperUpdated"}
|
||||
json={"first_name": "SuperUpdated"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -214,29 +217,37 @@ class TestUpdateUserById:
|
||||
assert data["first_name"] == "SuperUpdated"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_by_id_value_error(self, client, async_test_user, superuser_token):
|
||||
async def test_update_user_by_id_value_error(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test ValueError handling (covers lines 280-281)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.update', side_effect=ValueError("Invalid")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid")
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_by_id_unexpected_error(self, client, async_test_user, superuser_token):
|
||||
async def test_update_user_by_id_unexpected_error(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test unexpected error handling (covers lines 283-284)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.update', side_effect=Exception("Unexpected")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected")
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await client.patch(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={"first_name": "Updated"}
|
||||
json={"first_name": "Updated"},
|
||||
)
|
||||
|
||||
|
||||
@@ -246,18 +257,18 @@ class TestChangePassword:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_success(self, client, async_test_db):
|
||||
"""Test changing password successfully."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a fresh user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user import User
|
||||
|
||||
new_user = User(
|
||||
email="changepass@example.com",
|
||||
password_hash=get_password_hash("OldPassword123!"),
|
||||
first_name="Change",
|
||||
last_name="Pass"
|
||||
last_name="Pass",
|
||||
)
|
||||
session.add(new_user)
|
||||
await session.commit()
|
||||
@@ -265,10 +276,7 @@ class TestChangePassword:
|
||||
# Login
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "changepass@example.com",
|
||||
"password": "OldPassword123!"
|
||||
}
|
||||
json={"email": "changepass@example.com", "password": "OldPassword123!"},
|
||||
)
|
||||
token = login_response.json()["access_token"]
|
||||
|
||||
@@ -278,8 +286,8 @@ class TestChangePassword:
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={
|
||||
"current_password": "OldPassword123!",
|
||||
"new_password": "NewPassword456!"
|
||||
}
|
||||
"new_password": "NewPassword456!",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -289,10 +297,7 @@ class TestChangePassword:
|
||||
# Verify new password works
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "changepass@example.com",
|
||||
"password": "NewPassword456!"
|
||||
}
|
||||
json={"email": "changepass@example.com", "password": "NewPassword456!"},
|
||||
)
|
||||
assert login_response.status_code == status.HTTP_200_OK
|
||||
|
||||
@@ -306,7 +311,7 @@ class TestDeleteUserById:
|
||||
fake_id = uuid4()
|
||||
response = await client.delete(
|
||||
f"/api/v1/users/{fake_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -314,18 +319,18 @@ class TestDeleteUserById:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_success(self, client, async_test_db, superuser_token):
|
||||
"""Test deleting user successfully (covers lines 383-388)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a user to delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.models.user import User
|
||||
|
||||
user_to_delete = User(
|
||||
email=f"delete{uuid4().hex[:8]}@example.com",
|
||||
password_hash=get_password_hash("Password123!"),
|
||||
first_name="Delete",
|
||||
last_name="Me"
|
||||
last_name="Me",
|
||||
)
|
||||
session.add(user_to_delete)
|
||||
await session.commit()
|
||||
@@ -334,7 +339,7 @@ class TestDeleteUserById:
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/users/{user_id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -342,25 +347,35 @@ class TestDeleteUserById:
|
||||
assert data["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_value_error(self, client, async_test_user, superuser_token):
|
||||
async def test_delete_user_value_error(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test ValueError handling during delete (covers lines 390-391)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=ValueError("Cannot delete")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.soft_delete",
|
||||
side_effect=ValueError("Cannot delete"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
await client.delete(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_unexpected_error(self, client, async_test_user, superuser_token):
|
||||
async def test_delete_user_unexpected_error(
|
||||
self, client, async_test_user, superuser_token
|
||||
):
|
||||
"""Test unexpected error handling during delete (covers lines 393-394)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch('app.api.routes.users.user_crud.soft_delete', side_effect=Exception("Unexpected")):
|
||||
with patch(
|
||||
"app.api.routes.users.user_crud.soft_delete",
|
||||
side_effect=Exception("Unexpected"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await client.delete(
|
||||
f"/api/v1/users/{async_test_user.id}",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"}
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
@@ -1,28 +1,32 @@
|
||||
# tests/conftest.py
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
# Set IS_TEST environment variable BEFORE importing app
|
||||
# This prevents the scheduler from starting during tests
|
||||
os.environ["IS_TEST"] = "True"
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.core.auth import get_password_hash
|
||||
from app.utils.test_utils import setup_test_db, teardown_test_db, setup_async_test_db, teardown_async_test_db
|
||||
from app.core.database import get_db
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.utils.test_utils import (
|
||||
setup_async_test_db,
|
||||
setup_test_db,
|
||||
teardown_async_test_db,
|
||||
teardown_test_db,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
"""
|
||||
Creates a fresh SQLite in-memory database for each test function.
|
||||
|
||||
|
||||
Yields a SQLAlchemy session that can be used for testing.
|
||||
"""
|
||||
# Set up the database
|
||||
@@ -46,6 +50,7 @@ async def async_test_db():
|
||||
yield test_engine, AsyncTestingSessionLocal
|
||||
await teardown_async_test_db(test_engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_create_data():
|
||||
return {
|
||||
@@ -55,7 +60,7 @@ def user_create_data():
|
||||
"last_name": "User",
|
||||
"phone_number": "+1234567890",
|
||||
"is_superuser": False,
|
||||
"preferences": None
|
||||
"preferences": None,
|
||||
}
|
||||
|
||||
|
||||
@@ -102,7 +107,7 @@ async def client(async_test_db):
|
||||
|
||||
This overrides the get_db dependency to use the test database.
|
||||
"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async def override_get_db():
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -176,7 +181,7 @@ async def async_test_user(async_test_db):
|
||||
|
||||
Password: TestPassword123
|
||||
"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
@@ -202,7 +207,7 @@ async def async_test_superuser(async_test_db):
|
||||
|
||||
Password: SuperPassword123
|
||||
"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
@@ -256,4 +261,4 @@ async def superuser_token(client, async_test_superuser):
|
||||
)
|
||||
assert response.status_code == 200, f"Login failed: {response.text}"
|
||||
tokens = response.json()
|
||||
return tokens["access_token"]
|
||||
return tokens["access_token"]
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# tests/core/test_auth.py
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from jose import jwt
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
TokenMissingClaimError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
get_password_hash,
|
||||
get_token_data,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
TokenMissingClaimError
|
||||
verify_password,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -58,15 +58,13 @@ class TestTokenCreation:
|
||||
custom_claims = {
|
||||
"email": "test@example.com",
|
||||
"first_name": "Test",
|
||||
"is_superuser": True
|
||||
"is_superuser": True,
|
||||
}
|
||||
token = create_access_token(subject=user_id, claims=custom_claims)
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
@@ -87,9 +85,7 @@ class TestTokenCreation:
|
||||
|
||||
# Decode token to verify claims
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Check standard claims
|
||||
@@ -105,23 +101,18 @@ class TestTokenCreation:
|
||||
expires = timedelta(minutes=5)
|
||||
|
||||
# Create token with specific expiration
|
||||
token = create_access_token(
|
||||
subject=user_id,
|
||||
expires_delta=expires
|
||||
)
|
||||
token = create_access_token(subject=user_id, expires_delta=expires)
|
||||
|
||||
# Decode token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Get actual expiration time from token
|
||||
expiration = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
expiration = datetime.fromtimestamp(payload["exp"], tz=UTC)
|
||||
|
||||
# Calculate expected expiration (approximately)
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
expected_expiration = now + expires
|
||||
|
||||
# Difference should be small (less than 1 second)
|
||||
@@ -148,7 +139,7 @@ class TestTokenDecoding:
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# Create a token that's already expired by directly manipulating the payload
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
expired_time = now - timedelta(hours=1) # 1 hour in the past
|
||||
|
||||
# Create the expired token manually
|
||||
@@ -157,13 +148,11 @@ class TestTokenDecoding:
|
||||
"exp": int(expired_time.timestamp()), # Set expiration in the past
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
# Attempting to decode should raise TokenExpiredError
|
||||
@@ -180,20 +169,16 @@ class TestTokenDecoding:
|
||||
def test_decode_token_with_missing_sub(self):
|
||||
"""Test that a token without 'sub' claim raises TokenMissingClaimError"""
|
||||
# Create a token without a subject
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
payload = {
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
# No 'sub' claim
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
decode_token(token)
|
||||
@@ -211,20 +196,16 @@ class TestTokenDecoding:
|
||||
"""Test that a token with invalid payload structure raises TokenInvalidError"""
|
||||
# Create a token with an invalid payload structure - missing 'sub' which is required
|
||||
# but including 'exp' to avoid the expiration check
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
payload = {
|
||||
# Missing "sub" field which is required
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"invalid_field": "test"
|
||||
"invalid_field": "test",
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
# Should raise TokenMissingClaimError due to missing 'sub'
|
||||
with pytest.raises(TokenMissingClaimError):
|
||||
@@ -236,11 +217,7 @@ class TestTokenDecoding:
|
||||
"exp": int((now + timedelta(minutes=30)).timestamp()),
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
# Should raise TokenInvalidError due to ValidationError
|
||||
with pytest.raises(TokenInvalidError):
|
||||
@@ -249,12 +226,9 @@ class TestTokenDecoding:
|
||||
def test_get_token_data(self):
|
||||
"""Test extracting TokenData from a token"""
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(
|
||||
subject=str(user_id),
|
||||
claims={"is_superuser": True}
|
||||
)
|
||||
token = create_access_token(subject=str(user_id), claims={"is_superuser": True})
|
||||
|
||||
token_data = get_token_data(token)
|
||||
|
||||
assert token_data.user_id == user_id
|
||||
assert token_data.is_superuser is True
|
||||
assert token_data.is_superuser is True
|
||||
|
||||
@@ -8,11 +8,11 @@ Critical security tests covering:
|
||||
|
||||
These tests cover critical security vulnerabilities that could be exploited.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from jose import jwt
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app.core.auth import decode_token, create_access_token, TokenInvalidError
|
||||
from app.core.auth import TokenInvalidError, create_access_token, decode_token
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
@@ -46,13 +46,14 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
"""
|
||||
# Create a payload that would normally be valid (using timestamps)
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600, # 1 hour from now
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
# Craft a malicious token with "alg: none"
|
||||
@@ -61,13 +62,13 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
import json
|
||||
|
||||
header = {"alg": "none", "typ": "JWT"}
|
||||
header_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(header).encode()
|
||||
).decode().rstrip("=")
|
||||
header_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
payload_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(payload).encode()
|
||||
).decode().rstrip("=")
|
||||
payload_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
# Token with no signature (algorithm "none")
|
||||
malicious_token = f"{header_encoded}.{payload_encoded}."
|
||||
@@ -85,22 +86,17 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
|
||||
|
||||
# Try uppercase "NONE"
|
||||
header = {"alg": "NONE", "typ": "JWT"}
|
||||
header_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(header).encode()
|
||||
).decode().rstrip("=")
|
||||
header_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
payload_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(payload).encode()
|
||||
).decode().rstrip("=")
|
||||
payload_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
malicious_token = f"{header_encoded}.{payload_encoded}."
|
||||
|
||||
@@ -121,15 +117,11 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
before our defensive checks at line 212. This is good for security!
|
||||
"""
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
# Create a valid payload
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
|
||||
|
||||
# Encode with wrong algorithm (RS256 instead of HS256)
|
||||
# This simulates an attacker trying algorithm substitution
|
||||
@@ -137,9 +129,7 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
|
||||
try:
|
||||
malicious_token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=wrong_algorithm
|
||||
payload, settings.SECRET_KEY, algorithm=wrong_algorithm
|
||||
)
|
||||
|
||||
# Should reject the token (library catches mismatch)
|
||||
@@ -156,21 +146,15 @@ class TestJWTAlgorithmSecurityAttacks:
|
||||
Prevents algorithm downgrade/upgrade attacks.
|
||||
"""
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
|
||||
|
||||
# Create token with HS384 instead of HS256
|
||||
try:
|
||||
malicious_token = jwt.encode(
|
||||
payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm="HS384"
|
||||
payload, settings.SECRET_KEY, algorithm="HS384"
|
||||
)
|
||||
|
||||
with pytest.raises(TokenInvalidError):
|
||||
@@ -223,20 +207,15 @@ class TestJWTSecurityEdgeCases:
|
||||
|
||||
# Create token without "alg" in header
|
||||
header = {"typ": "JWT"} # Missing "alg"
|
||||
payload = {
|
||||
"sub": "user123",
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
"type": "access"
|
||||
}
|
||||
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
|
||||
|
||||
header_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(header).encode()
|
||||
).decode().rstrip("=")
|
||||
header_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
payload_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(payload).encode()
|
||||
).decode().rstrip("=")
|
||||
payload_encoded = (
|
||||
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
|
||||
)
|
||||
|
||||
malicious_token = f"{header_encoded}.{payload_encoded}.fake_signature"
|
||||
|
||||
@@ -253,15 +232,20 @@ class TestJWTSecurityEdgeCases:
|
||||
"""Test token with malformed JSON in payload."""
|
||||
import base64
|
||||
|
||||
header = {"alg": "HS256", "typ": "JWT"}
|
||||
header_encoded = base64.urlsafe_b64encode(
|
||||
b'{"alg":"HS256","typ":"JWT"}'
|
||||
).decode().rstrip("=")
|
||||
header_encoded = (
|
||||
base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}')
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
# Invalid JSON (missing closing brace)
|
||||
invalid_payload_encoded = base64.urlsafe_b64encode(
|
||||
b'{"sub":"user123"' # Invalid JSON
|
||||
).decode().rstrip("=")
|
||||
invalid_payload_encoded = (
|
||||
base64.urlsafe_b64encode(
|
||||
b'{"sub":"user123"' # Invalid JSON
|
||||
)
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
malicious_token = f"{header_encoded}.{invalid_payload_encoded}.fake_sig"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# tests/core/test_config.py
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import Settings
|
||||
|
||||
|
||||
@@ -22,11 +23,15 @@ class TestSecretKeyValidation:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(SECRET_KEY=default_key, ENVIRONMENT="production")
|
||||
|
||||
assert "must be set to a secure random value in production" in str(exc_info.value)
|
||||
assert "must be set to a secure random value in production" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
def test_default_secret_key_in_development_allows_with_warning(self, caplog):
|
||||
"""Test that default SECRET_KEY in development is allowed but warns"""
|
||||
settings = Settings(SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development")
|
||||
settings = Settings(
|
||||
SECRET_KEY="your_secret_key_here" + "x" * 14, ENVIRONMENT="development"
|
||||
)
|
||||
|
||||
assert settings.SECRET_KEY == "your_secret_key_here" + "x" * 14
|
||||
# Note: The warning happens during validation, which we've seen works
|
||||
@@ -44,19 +49,13 @@ class TestSuperuserPasswordValidation:
|
||||
|
||||
def test_none_password_accepted(self):
|
||||
"""Test that None password is accepted (optional field)"""
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD=None
|
||||
)
|
||||
settings = Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=None)
|
||||
assert settings.FIRST_SUPERUSER_PASSWORD is None
|
||||
|
||||
def test_password_too_short_raises_error(self):
|
||||
"""Test that password shorter than 12 characters raises error"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="Short1"
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="Short1")
|
||||
|
||||
assert "must be at least 12 characters" in str(exc_info.value)
|
||||
|
||||
@@ -64,14 +63,11 @@ class TestSuperuserPasswordValidation:
|
||||
"""Test that common weak passwords are rejected"""
|
||||
# Test with the exact weak passwords from the validator
|
||||
# These are in the weak_passwords set and should be rejected
|
||||
weak_passwords = ['123456789012'] # Exactly 12 chars, in the weak set
|
||||
weak_passwords = ["123456789012"] # Exactly 12 chars, in the weak set
|
||||
|
||||
for weak_pwd in weak_passwords:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD=weak_pwd
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=weak_pwd)
|
||||
# Should get "too weak" message
|
||||
error_str = str(exc_info.value)
|
||||
assert "too weak" in error_str
|
||||
@@ -79,30 +75,21 @@ class TestSuperuserPasswordValidation:
|
||||
def test_password_without_lowercase_rejected(self):
|
||||
"""Test that password without lowercase is rejected"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123"
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="ALLUPPERCASE123")
|
||||
|
||||
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
|
||||
|
||||
def test_password_without_uppercase_rejected(self):
|
||||
"""Test that password without uppercase is rejected"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="alllowercase123"
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="alllowercase123")
|
||||
|
||||
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
|
||||
|
||||
def test_password_without_digit_rejected(self):
|
||||
"""Test that password without digit is rejected"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD="NoDigitsHere"
|
||||
)
|
||||
Settings(SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD="NoDigitsHere")
|
||||
|
||||
assert "must contain lowercase, uppercase, and digits" in str(exc_info.value)
|
||||
|
||||
@@ -110,8 +97,7 @@ class TestSuperuserPasswordValidation:
|
||||
"""Test that strong password is accepted"""
|
||||
strong_password = "StrongPassword123!"
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
FIRST_SUPERUSER_PASSWORD=strong_password
|
||||
SECRET_KEY="a" * 32, FIRST_SUPERUSER_PASSWORD=strong_password
|
||||
)
|
||||
|
||||
assert settings.FIRST_SUPERUSER_PASSWORD == strong_password
|
||||
@@ -150,7 +136,7 @@ class TestDatabaseConfiguration:
|
||||
POSTGRES_HOST="testhost",
|
||||
POSTGRES_PORT="5432",
|
||||
POSTGRES_DB="testdb",
|
||||
DATABASE_URL=None # Don't use explicit URL
|
||||
DATABASE_URL=None, # Don't use explicit URL
|
||||
)
|
||||
|
||||
expected_url = "postgresql://testuser:testpass@testhost:5432/testdb"
|
||||
@@ -159,10 +145,7 @@ class TestDatabaseConfiguration:
|
||||
def test_explicit_database_url_used_when_set(self):
|
||||
"""Test that explicit DATABASE_URL is used when provided"""
|
||||
explicit_url = "postgresql://explicit:pass@host:5432/db"
|
||||
settings = Settings(
|
||||
SECRET_KEY="a" * 32,
|
||||
DATABASE_URL=explicit_url
|
||||
)
|
||||
settings = Settings(SECRET_KEY="a" * 32, DATABASE_URL=explicit_url)
|
||||
|
||||
assert settings.database_url == explicit_url
|
||||
|
||||
|
||||
@@ -6,8 +6,10 @@ Critical security tests covering:
|
||||
|
||||
These tests prevent security misconfigurations.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
@@ -43,6 +45,7 @@ class TestSecretKeySecurityValidation:
|
||||
# Import Settings class fresh (to pick up new env var)
|
||||
# The ValidationError should be raised during reload when Settings() is instantiated
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
# Reload will raise ValidationError because Settings() is instantiated at module level
|
||||
@@ -58,7 +61,9 @@ class TestSecretKeySecurityValidation:
|
||||
|
||||
# Reload config to restore original settings
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
def test_secret_key_exactly_32_characters_accepted(self):
|
||||
@@ -75,7 +80,9 @@ class TestSecretKeySecurityValidation:
|
||||
os.environ["SECRET_KEY"] = key_32
|
||||
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Should work
|
||||
@@ -89,7 +96,9 @@ class TestSecretKeySecurityValidation:
|
||||
os.environ.pop("SECRET_KEY", None)
|
||||
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
def test_secret_key_long_enough_accepted(self):
|
||||
@@ -106,7 +115,9 @@ class TestSecretKeySecurityValidation:
|
||||
os.environ["SECRET_KEY"] = key_64
|
||||
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
# Should work
|
||||
@@ -120,7 +131,9 @@ class TestSecretKeySecurityValidation:
|
||||
os.environ.pop("SECRET_KEY", None)
|
||||
|
||||
import importlib
|
||||
|
||||
from app.core import config
|
||||
|
||||
importlib.reload(config)
|
||||
|
||||
def test_default_secret_key_meets_requirements(self):
|
||||
@@ -132,4 +145,6 @@ class TestSecretKeySecurityValidation:
|
||||
from app.core.config import settings
|
||||
|
||||
# Current settings should have valid SECRET_KEY
|
||||
assert len(settings.SECRET_KEY) >= 32, "Default SECRET_KEY must be at least 32 chars"
|
||||
assert len(settings.SECRET_KEY) >= 32, (
|
||||
"Default SECRET_KEY must be at least 32 chars"
|
||||
)
|
||||
|
||||
@@ -9,18 +9,19 @@ Covers:
|
||||
- init_async_db
|
||||
- close_async_db
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import (
|
||||
get_async_database_url,
|
||||
get_db,
|
||||
async_transaction_scope,
|
||||
check_async_database_health,
|
||||
init_async_db,
|
||||
close_async_db,
|
||||
get_async_database_url,
|
||||
get_db,
|
||||
init_async_db,
|
||||
)
|
||||
|
||||
|
||||
@@ -88,12 +89,13 @@ class TestAsyncTransactionScope:
|
||||
async def test_transaction_scope_commits_on_success(self, async_test_db):
|
||||
"""Test that successful operations are committed (covers line 138)."""
|
||||
# Mock the transaction scope to use test database
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
with patch('app.core.database.SessionLocal', SessionLocal):
|
||||
with patch("app.core.database.SessionLocal", SessionLocal):
|
||||
async with async_transaction_scope() as db:
|
||||
# Execute a simple query to verify transaction works
|
||||
from sqlalchemy import text
|
||||
|
||||
result = await db.execute(text("SELECT 1"))
|
||||
assert result is not None
|
||||
# Transaction should be committed (covers line 138 debug log)
|
||||
@@ -101,12 +103,13 @@ class TestAsyncTransactionScope:
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_scope_rollback_on_error(self, async_test_db):
|
||||
"""Test that transaction rolls back on exception."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
with patch('app.core.database.SessionLocal', SessionLocal):
|
||||
with patch("app.core.database.SessionLocal", SessionLocal):
|
||||
with pytest.raises(RuntimeError, match="Test error"):
|
||||
async with async_transaction_scope() as db:
|
||||
from sqlalchemy import text
|
||||
|
||||
await db.execute(text("SELECT 1"))
|
||||
raise RuntimeError("Test error")
|
||||
|
||||
@@ -117,9 +120,9 @@ class TestCheckAsyncDatabaseHealth:
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_health_check_success(self, async_test_db):
|
||||
"""Test health check returns True on success (covers line 156)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
with patch('app.core.database.SessionLocal', SessionLocal):
|
||||
with patch("app.core.database.SessionLocal", SessionLocal):
|
||||
result = await check_async_database_health()
|
||||
assert result is True
|
||||
|
||||
@@ -127,7 +130,7 @@ class TestCheckAsyncDatabaseHealth:
|
||||
async def test_database_health_check_failure(self):
|
||||
"""Test health check returns False on database error."""
|
||||
# Mock async_transaction_scope to raise an error
|
||||
with patch('app.core.database.async_transaction_scope') as mock_scope:
|
||||
with patch("app.core.database.async_transaction_scope") as mock_scope:
|
||||
mock_scope.side_effect = Exception("Database connection failed")
|
||||
|
||||
result = await check_async_database_health()
|
||||
@@ -140,10 +143,10 @@ class TestInitAsyncDb:
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_async_db_creates_tables(self, async_test_db):
|
||||
"""Test init_async_db creates tables (covers lines 174-176)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
test_engine, _SessionLocal = async_test_db
|
||||
|
||||
# Mock the engine to use test engine
|
||||
with patch('app.core.database.engine', test_engine):
|
||||
with patch("app.core.database.engine", test_engine):
|
||||
await init_async_db()
|
||||
# If no exception, tables were created successfully
|
||||
|
||||
@@ -155,7 +158,6 @@ class TestCloseAsyncDb:
|
||||
async def test_close_async_db_disposes_engine(self):
|
||||
"""Test close_async_db disposes engine (covers lines 185-186)."""
|
||||
# Create a fresh engine to test closing
|
||||
from app.core.database import engine
|
||||
|
||||
# Close connections
|
||||
await close_async_db()
|
||||
|
||||
@@ -2,14 +2,16 @@
|
||||
"""
|
||||
Comprehensive tests for CRUDBase class covering all error paths and edge cases.
|
||||
"""
|
||||
|
||||
from datetime import UTC
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from uuid import uuid4, UUID
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
@@ -19,7 +21,7 @@ class TestCRUDBaseGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_invalid_uuid_string(self, async_test_db):
|
||||
"""Test get with invalid UUID string returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.get(session, id="invalid-uuid")
|
||||
@@ -28,7 +30,7 @@ class TestCRUDBaseGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_invalid_uuid_type(self, async_test_db):
|
||||
"""Test get with invalid UUID type returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.get(session, id=12345) # int instead of UUID
|
||||
@@ -37,7 +39,7 @@ class TestCRUDBaseGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_uuid_object(self, async_test_db, async_test_user):
|
||||
"""Test get with UUID object instead of string."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Pass UUID object directly
|
||||
@@ -48,26 +50,24 @@ class TestCRUDBaseGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_options(self, async_test_db, async_test_user):
|
||||
"""Test get with eager loading options (tests lines 76-78)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Test that options parameter is accepted and doesn't error
|
||||
# We pass an empty list which still tests the code path
|
||||
result = await user_crud.get(
|
||||
session,
|
||||
id=str(async_test_user.id),
|
||||
options=[]
|
||||
session, id=str(async_test_user.id), options=[]
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_database_error(self, async_test_db):
|
||||
"""Test get handles database errors properly."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock execute to raise an exception
|
||||
with patch.object(session, 'execute', side_effect=Exception("DB error")):
|
||||
with patch.object(session, "execute", side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.get(session, id=str(uuid4()))
|
||||
|
||||
@@ -78,7 +78,7 @@ class TestCRUDBaseGetMulti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_negative_skip(self, async_test_db):
|
||||
"""Test get_multi with negative skip raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
@@ -87,7 +87,7 @@ class TestCRUDBaseGetMulti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_negative_limit(self, async_test_db):
|
||||
"""Test get_multi with negative limit raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
@@ -96,7 +96,7 @@ class TestCRUDBaseGetMulti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_limit_too_large(self, async_test_db):
|
||||
"""Test get_multi with limit > 1000 raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
@@ -105,25 +105,20 @@ class TestCRUDBaseGetMulti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_options(self, async_test_db, async_test_user):
|
||||
"""Test get_multi with eager loading options (tests lines 118-120)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Test that options parameter is accepted
|
||||
results = await user_crud.get_multi(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
options=[]
|
||||
)
|
||||
results = await user_crud.get_multi(session, skip=0, limit=10, options=[])
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_database_error(self, async_test_db):
|
||||
"""Test get_multi handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'execute', side_effect=Exception("DB error")):
|
||||
with patch.object(session, "execute", side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.get_multi(session)
|
||||
|
||||
@@ -134,7 +129,7 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_unique_field(self, async_test_db, async_test_user):
|
||||
"""Test create with duplicate unique field raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Try to create user with duplicate email
|
||||
@@ -142,7 +137,7 @@ class TestCRUDBaseCreate:
|
||||
email=async_test_user.email, # Duplicate!
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="Duplicate"
|
||||
last_name="Duplicate",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
@@ -151,22 +146,23 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_integrity_error_non_duplicate(self, async_test_db):
|
||||
"""Test create with non-duplicate IntegrityError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock commit to raise IntegrityError without "unique" in message
|
||||
original_commit = session.commit
|
||||
|
||||
async def mock_commit():
|
||||
error = IntegrityError("statement", {}, Exception("foreign key violation"))
|
||||
error = IntegrityError(
|
||||
"statement", {}, Exception("foreign key violation")
|
||||
)
|
||||
raise error
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
@@ -175,15 +171,21 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_operational_error(self, async_test_db):
|
||||
"""Test create with OperationalError (user CRUD catches as generic Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection lost"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=OperationalError(
|
||||
"statement", {}, Exception("connection lost")
|
||||
),
|
||||
):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
@@ -193,15 +195,19 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_data_error(self, async_test_db):
|
||||
"""Test create with DataError (user CRUD catches as generic Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=DataError("statement", {}, Exception("invalid data"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=DataError("statement", {}, Exception("invalid data")),
|
||||
):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
@@ -211,15 +217,17 @@ class TestCRUDBaseCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_unexpected_error(self, async_test_db):
|
||||
"""Test create with unexpected exception."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected error")):
|
||||
with patch.object(
|
||||
session, "commit", side_effect=RuntimeError("Unexpected error")
|
||||
):
|
||||
user_data = UserCreate(
|
||||
email="test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Unexpected error"):
|
||||
@@ -232,16 +240,17 @@ class TestCRUDBaseUpdate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_duplicate_unique_field(self, async_test_db, async_test_user):
|
||||
"""Test update with duplicate unique field raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create another user
|
||||
async with SessionLocal() as session:
|
||||
from app.crud.user import user as user_crud
|
||||
|
||||
user2_data = UserCreate(
|
||||
email="user2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="User",
|
||||
last_name="Two"
|
||||
last_name="Two",
|
||||
)
|
||||
user2 = await user_crud.create(session, obj_in=user2_data)
|
||||
await session.commit()
|
||||
@@ -250,63 +259,89 @@ class TestCRUDBaseUpdate:
|
||||
async with SessionLocal() as session:
|
||||
user2_obj = await user_crud.get(session, id=str(user2.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("UNIQUE constraint failed"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=IntegrityError(
|
||||
"statement", {}, Exception("UNIQUE constraint failed")
|
||||
),
|
||||
):
|
||||
update_data = UserUpdate(email=async_test_user.email)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
await user_crud.update(session, db_obj=user2_obj, obj_in=update_data)
|
||||
await user_crud.update(
|
||||
session, db_obj=user2_obj, obj_in=update_data
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test update with dict instead of schema."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
# Update with dict (tests lines 164-165)
|
||||
updated = await user_crud.update(
|
||||
session,
|
||||
db_obj=user,
|
||||
obj_in={"first_name": "UpdatedName"}
|
||||
session, db_obj=user, obj_in={"first_name": "UpdatedName"}
|
||||
)
|
||||
assert updated.first_name == "UpdatedName"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_integrity_error(self, async_test_db, async_test_user):
|
||||
"""Test update with IntegrityError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("constraint failed"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=IntegrityError(
|
||||
"statement", {}, Exception("constraint failed")
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_operational_error(self, async_test_db, async_test_user):
|
||||
"""Test update with OperationalError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=OperationalError("statement", {}, Exception("connection error"))):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=OperationalError(
|
||||
"statement", {}, Exception("connection error")
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test update with unexpected error."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
|
||||
with patch.object(
|
||||
session, "commit", side_effect=RuntimeError("Unexpected")
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Test"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
|
||||
|
||||
class TestCRUDBaseRemove:
|
||||
@@ -315,7 +350,7 @@ class TestCRUDBaseRemove:
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_invalid_uuid(self, async_test_db):
|
||||
"""Test remove with invalid UUID returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id="invalid-uuid")
|
||||
@@ -324,7 +359,7 @@ class TestCRUDBaseRemove:
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_with_uuid_object(self, async_test_db, async_test_user):
|
||||
"""Test remove with UUID object."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a user to delete
|
||||
async with SessionLocal() as session:
|
||||
@@ -332,7 +367,7 @@ class TestCRUDBaseRemove:
|
||||
email="todelete@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="To",
|
||||
last_name="Delete"
|
||||
last_name="Delete",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -347,7 +382,7 @@ class TestCRUDBaseRemove:
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_nonexistent(self, async_test_db):
|
||||
"""Test remove of nonexistent record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id=str(uuid4()))
|
||||
@@ -356,21 +391,31 @@ class TestCRUDBaseRemove:
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_integrity_error(self, async_test_db, async_test_user):
|
||||
"""Test remove with IntegrityError (foreign key constraint)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock delete to raise IntegrityError
|
||||
with patch.object(session, 'commit', side_effect=IntegrityError("statement", {}, Exception("FOREIGN KEY constraint"))):
|
||||
with pytest.raises(ValueError, match="Cannot delete.*referenced by other records"):
|
||||
with patch.object(
|
||||
session,
|
||||
"commit",
|
||||
side_effect=IntegrityError(
|
||||
"statement", {}, Exception("FOREIGN KEY constraint")
|
||||
),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot delete.*referenced by other records"
|
||||
):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test remove with unexpected error."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'commit', side_effect=RuntimeError("Unexpected")):
|
||||
with patch.object(
|
||||
session, "commit", side_effect=RuntimeError("Unexpected")
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
@@ -381,10 +426,12 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
|
||||
"""Test get_multi_with_total basic functionality."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, total = await user_crud.get_multi_with_total(session, skip=0, limit=10)
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10
|
||||
)
|
||||
assert isinstance(items, list)
|
||||
assert isinstance(total, int)
|
||||
assert total >= 1 # At least the test user
|
||||
@@ -392,7 +439,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_skip(self, async_test_db):
|
||||
"""Test get_multi_with_total with negative skip raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
@@ -401,7 +448,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_limit(self, async_test_db):
|
||||
"""Test get_multi_with_total with negative limit raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
@@ -410,28 +457,34 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
|
||||
"""Test get_multi_with_total with limit > 1000 raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_total_with_filters(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_multi_with_total with filters."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
filters = {"email": async_test_user.email}
|
||||
items, total = await user_crud.get_multi_with_total(session, filters=filters)
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
session, filters=filters
|
||||
)
|
||||
assert total == 1
|
||||
assert len(items) == 1
|
||||
assert items[0].email == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_total_with_sorting_asc(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_multi_with_total with ascending sort."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
@@ -439,13 +492,13 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
email="aaa@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="AAA",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user_data2 = UserCreate(
|
||||
email="zzz@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="ZZZ",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
@@ -460,9 +513,11 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
assert items[0].email == "aaa@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_total_with_sorting_desc(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_multi_with_total with descending sort."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
@@ -470,20 +525,20 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
email="bbb@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="BBB",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user_data2 = UserCreate(
|
||||
email="ccc@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="CCC",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
items, _total = await user_crud.get_multi_with_total(
|
||||
session, sort_by="email", sort_order="desc", limit=1
|
||||
)
|
||||
assert len(items) == 1
|
||||
@@ -492,7 +547,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_pagination(self, async_test_db):
|
||||
"""Test get_multi_with_total pagination works correctly."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create minimal users for pagination test (3 instead of 5)
|
||||
async with SessionLocal() as session:
|
||||
@@ -501,19 +556,23 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
email=f"user{i}@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Get first page
|
||||
items1, total = await user_crud.get_multi_with_total(session, skip=0, limit=2)
|
||||
items1, total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=2
|
||||
)
|
||||
assert len(items1) == 2
|
||||
assert total >= 3
|
||||
|
||||
# Get second page
|
||||
items2, total2 = await user_crud.get_multi_with_total(session, skip=2, limit=2)
|
||||
items2, total2 = await user_crud.get_multi_with_total(
|
||||
session, skip=2, limit=2
|
||||
)
|
||||
assert len(items2) >= 1
|
||||
assert total2 == total
|
||||
|
||||
@@ -529,7 +588,7 @@ class TestCRUDBaseCount:
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_basic(self, async_test_db, async_test_user):
|
||||
"""Test count returns correct number."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
count = await user_crud.count(session)
|
||||
@@ -539,7 +598,7 @@ class TestCRUDBaseCount:
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_multiple_users(self, async_test_db, async_test_user):
|
||||
"""Test count with multiple users."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
@@ -549,13 +608,13 @@ class TestCRUDBaseCount:
|
||||
email="count1@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Count",
|
||||
last_name="One"
|
||||
last_name="One",
|
||||
)
|
||||
user_data2 = UserCreate(
|
||||
email="count2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Count",
|
||||
last_name="Two"
|
||||
last_name="Two",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
@@ -568,10 +627,10 @@ class TestCRUDBaseCount:
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_database_error(self, async_test_db):
|
||||
"""Test count handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, 'execute', side_effect=Exception("DB error")):
|
||||
with patch.object(session, "execute", side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.count(session)
|
||||
|
||||
@@ -582,7 +641,7 @@ class TestCRUDBaseExists:
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_true(self, async_test_db, async_test_user):
|
||||
"""Test exists returns True for existing record."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id=str(async_test_user.id))
|
||||
@@ -591,7 +650,7 @@ class TestCRUDBaseExists:
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_false(self, async_test_db):
|
||||
"""Test exists returns False for non-existent record."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id=str(uuid4()))
|
||||
@@ -600,7 +659,7 @@ class TestCRUDBaseExists:
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists_invalid_uuid(self, async_test_db):
|
||||
"""Test exists returns False for invalid UUID."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id="invalid-uuid")
|
||||
@@ -613,7 +672,7 @@ class TestCRUDBaseSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_success(self, async_test_db):
|
||||
"""Test soft delete sets deleted_at timestamp."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a user to soft delete
|
||||
async with SessionLocal() as session:
|
||||
@@ -621,7 +680,7 @@ class TestCRUDBaseSoftDelete:
|
||||
email="softdelete@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Soft",
|
||||
last_name="Delete"
|
||||
last_name="Delete",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -636,7 +695,7 @@ class TestCRUDBaseSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_invalid_uuid(self, async_test_db):
|
||||
"""Test soft delete with invalid UUID returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.soft_delete(session, id="invalid-uuid")
|
||||
@@ -645,7 +704,7 @@ class TestCRUDBaseSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_nonexistent(self, async_test_db):
|
||||
"""Test soft delete of nonexistent record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.soft_delete(session, id=str(uuid4()))
|
||||
@@ -654,7 +713,7 @@ class TestCRUDBaseSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_with_uuid_object(self, async_test_db):
|
||||
"""Test soft delete with UUID object."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a user to soft delete
|
||||
async with SessionLocal() as session:
|
||||
@@ -662,7 +721,7 @@ class TestCRUDBaseSoftDelete:
|
||||
email="softdelete2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Soft",
|
||||
last_name="Delete2"
|
||||
last_name="Delete2",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -681,7 +740,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_success(self, async_test_db):
|
||||
"""Test restore clears deleted_at timestamp."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
@@ -689,7 +748,7 @@ class TestCRUDBaseRestore:
|
||||
email="restore@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -707,7 +766,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_invalid_uuid(self, async_test_db):
|
||||
"""Test restore with invalid UUID returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.restore(session, id="invalid-uuid")
|
||||
@@ -716,7 +775,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_nonexistent(self, async_test_db):
|
||||
"""Test restore of nonexistent record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.restore(session, id=str(uuid4()))
|
||||
@@ -725,7 +784,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_not_deleted(self, async_test_db, async_test_user):
|
||||
"""Test restore of non-deleted record returns None."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Try to restore a user that's not deleted
|
||||
@@ -735,7 +794,7 @@ class TestCRUDBaseRestore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_with_uuid_object(self, async_test_db):
|
||||
"""Test restore with UUID object."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
@@ -743,7 +802,7 @@ class TestCRUDBaseRestore:
|
||||
email="restore2@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test2"
|
||||
last_name="Test2",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -765,7 +824,7 @@ class TestCRUDBasePaginationValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_skip(self, async_test_db):
|
||||
"""Test that negative skip raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
@@ -774,7 +833,7 @@ class TestCRUDBasePaginationValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_limit(self, async_test_db):
|
||||
"""Test that negative limit raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
@@ -783,23 +842,22 @@ class TestCRUDBasePaginationValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
|
||||
"""Test that limit > 1000 raises ValueError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_filters(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_total_with_filters(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test pagination with filters (covers lines 270-273)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
filters={"is_active": True}
|
||||
session, skip=0, limit=10, filters={"is_active": True}
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
assert total >= 0
|
||||
@@ -807,30 +865,22 @@ class TestCRUDBasePaginationValidation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_desc(self, async_test_db):
|
||||
"""Test pagination with descending sort (covers lines 283-284)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="created_at",
|
||||
sort_order="desc"
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_sorting_asc(self, async_test_db):
|
||||
"""Test pagination with ascending sort (covers lines 285-286)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="created_at",
|
||||
sort_order="asc"
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
|
||||
@@ -842,13 +892,15 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_model_without_deleted_at(self, async_test_db, async_test_user):
|
||||
async def test_soft_delete_model_without_deleted_at(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test soft_delete on Organization model (no deleted_at) raises ValueError (covers lines 342-343)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization (which doesn't have deleted_at)
|
||||
from app.models.organization import Organization
|
||||
from app.crud.organization import organization as org_crud
|
||||
from app.models.organization import Organization
|
||||
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
@@ -864,11 +916,11 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_model_without_deleted_at(self, async_test_db):
|
||||
"""Test restore on Organization model (no deleted_at) raises ValueError (covers lines 383-384)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization (which doesn't have deleted_at)
|
||||
from app.models.organization import Organization
|
||||
from app.crud.organization import organization as org_crud
|
||||
from app.models.organization import Organization
|
||||
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(name="Restore Test", slug="restore-test")
|
||||
@@ -889,14 +941,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_real_eager_loading_options(self, async_test_db, async_test_user):
|
||||
async def test_get_with_real_eager_loading_options(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get() with actual eager loading options (covers lines 77-78)."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
test_engine, SessionLocal = async_test_db
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session for the user
|
||||
from app.models.user_session import UserSession
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -905,8 +960,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
device_id="test-device",
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Test Agent",
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.now(UTC) + timedelta(days=60),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -917,7 +972,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
result = await session_crud.get(
|
||||
session,
|
||||
id=str(session_id),
|
||||
options=[joinedload(UserSession.user)] # Real option, not empty list
|
||||
options=[joinedload(UserSession.user)], # Real option, not empty list
|
||||
)
|
||||
assert result is not None
|
||||
assert result.id == session_id
|
||||
@@ -925,14 +980,17 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
assert result.user.email == async_test_user.email
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_real_eager_loading_options(self, async_test_db, async_test_user):
|
||||
async def test_get_multi_with_real_eager_loading_options(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_multi() with actual eager loading options (covers lines 119-120)."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
test_engine, SessionLocal = async_test_db
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create multiple sessions for the user
|
||||
from app.models.user_session import UserSession
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
|
||||
async with SessionLocal() as session:
|
||||
for i in range(3):
|
||||
@@ -942,8 +1000,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
device_id=f"device-{i}",
|
||||
ip_address=f"192.168.1.{i}",
|
||||
user_agent=f"Agent {i}",
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=60)
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.now(UTC) + timedelta(days=60),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -954,7 +1012,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
options=[joinedload(UserSession.user)] # Real option, not empty list
|
||||
options=[joinedload(UserSession.user)], # Real option, not empty list
|
||||
)
|
||||
assert len(results) >= 3
|
||||
# Verify we can access user without additional queries
|
||||
|
||||
@@ -3,13 +3,15 @@
|
||||
Comprehensive tests for base CRUD database failure scenarios.
|
||||
Tests exception handling, rollbacks, and error messages.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import DataError, OperationalError
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
class TestBaseCRUDCreateFailures:
|
||||
@@ -18,19 +20,24 @@ class TestBaseCRUDCreateFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_operational_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection lost", {}, Exception("DB connection failed"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
async def mock_commit():
|
||||
raise OperationalError(
|
||||
"Connection lost", {}, Exception("DB connection failed")
|
||||
)
|
||||
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="operror@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
@@ -43,19 +50,22 @@ class TestBaseCRUDCreateFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_data_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise DataError("Invalid data type", {}, Exception("Data overflow"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="dataerror@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
@@ -67,19 +77,22 @@ class TestBaseCRUDCreateFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_unexpected_exception_triggers_rollback(self, async_test_db):
|
||||
"""Test that unexpected exceptions trigger rollback and re-raise."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Unexpected database error")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
user_data = UserCreate(
|
||||
email="unexpected@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Test",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Unexpected database error"):
|
||||
@@ -94,7 +107,7 @@ class TestBaseCRUDUpdateFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_operational_error(self, async_test_db, async_test_user):
|
||||
"""Test update with OperationalError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -102,17 +115,21 @@ class TestBaseCRUDUpdateFailures:
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection timeout", {}, Exception("Timeout"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_data_error(self, async_test_db, async_test_user):
|
||||
"""Test update with DataError."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -120,17 +137,21 @@ class TestBaseCRUDUpdateFailures:
|
||||
async def mock_commit():
|
||||
raise DataError("Invalid data", {}, Exception("Data type mismatch"))
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_unexpected_error(self, async_test_db, async_test_user):
|
||||
"""Test update with unexpected error."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -138,10 +159,14 @@ class TestBaseCRUDUpdateFailures:
|
||||
async def mock_commit():
|
||||
raise KeyError("Unexpected error")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(KeyError):
|
||||
await user_crud.update(session, db_obj=user, obj_in={"first_name": "Updated"})
|
||||
await user_crud.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Updated"}
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@@ -150,16 +175,21 @@ class TestBaseCRUDRemoveFailures:
|
||||
"""Test base CRUD remove method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_remove_unexpected_error_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test that unexpected errors in remove trigger rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Database write failed")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Database write failed"):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
|
||||
@@ -172,16 +202,15 @@ class TestBaseCRUDGetMultiWithTotalFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_database_error(self, async_test_db):
|
||||
"""Test get_multi_with_total handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Mock execute to raise an error
|
||||
original_execute = session.execute
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query failed", {}, Exception("Database error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=10)
|
||||
|
||||
@@ -192,13 +221,14 @@ class TestBaseCRUDCountFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_database_error_propagates(self, async_test_db):
|
||||
"""Test count propagates database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Count failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.count(session)
|
||||
|
||||
@@ -207,16 +237,21 @@ class TestBaseCRUDSoftDeleteFailures:
|
||||
"""Test soft_delete method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_soft_delete_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_soft_delete_unexpected_error_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test soft_delete handles unexpected errors with rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Soft delete failed")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Soft delete failed"):
|
||||
await user_crud.soft_delete(session, id=str(async_test_user.id))
|
||||
|
||||
@@ -229,7 +264,7 @@ class TestBaseCRUDRestoreFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_unexpected_error_triggers_rollback(self, async_test_db):
|
||||
"""Test restore handles unexpected errors with rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# First create and soft delete a user
|
||||
async with SessionLocal() as session:
|
||||
@@ -237,7 +272,7 @@ class TestBaseCRUDRestoreFailures:
|
||||
email="restore_test@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="Restore",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -248,11 +283,14 @@ class TestBaseCRUDRestoreFailures:
|
||||
|
||||
# Now test restore failure
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Restore failed")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(RuntimeError, match="Restore failed"):
|
||||
await user_crud.restore(session, id=str(user_id))
|
||||
|
||||
@@ -265,13 +303,14 @@ class TestBaseCRUDGetFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_database_error_propagates(self, async_test_db):
|
||||
"""Test get propagates database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Get failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get(session, id=str(uuid4()))
|
||||
|
||||
@@ -282,12 +321,13 @@ class TestBaseCRUDGetMultiFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_database_error_propagates(self, async_test_db):
|
||||
"""Test get_multi propagates database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query failed", {}, Exception("DB error"))
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.get_multi(session, skip=0, limit=10)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,10 +2,12 @@
|
||||
"""
|
||||
Comprehensive tests for async session CRUD operations.
|
||||
"""
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
@@ -17,7 +19,7 @@ class TestGetByJti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_success(self, async_test_db, async_test_user):
|
||||
"""Test getting session by JTI."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -27,8 +29,8 @@ class TestGetByJti:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -41,7 +43,7 @@ class TestGetByJti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_not_found(self, async_test_db):
|
||||
"""Test getting non-existent JTI returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.get_by_jti(session, jti="nonexistent")
|
||||
@@ -54,7 +56,7 @@ class TestGetActiveByJti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_success(self, async_test_db, async_test_user):
|
||||
"""Test getting active session by JTI."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -64,8 +66,8 @@ class TestGetActiveByJti:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -78,7 +80,7 @@ class TestGetActiveByJti:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_inactive(self, async_test_db, async_test_user):
|
||||
"""Test getting inactive session by JTI returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -88,8 +90,8 @@ class TestGetActiveByJti:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -105,7 +107,7 @@ class TestGetUserSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
|
||||
"""Test getting only active user sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active = UserSession(
|
||||
@@ -115,8 +117,8 @@ class TestGetUserSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
inactive = UserSession(
|
||||
user_id=async_test_user.id,
|
||||
@@ -125,17 +127,15 @@ class TestGetUserSessions:
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add_all([active, inactive])
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
active_only=True
|
||||
session, user_id=str(async_test_user.id), active_only=True
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0].is_active is True
|
||||
@@ -143,7 +143,7 @@ class TestGetUserSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
|
||||
"""Test getting all user sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
@@ -154,17 +154,15 @@ class TestGetUserSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=i % 2 == 0,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
active_only=False
|
||||
session, user_id=str(async_test_user.id), active_only=False
|
||||
)
|
||||
assert len(results) == 3
|
||||
|
||||
@@ -175,7 +173,7 @@ class TestCreateSession:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully creating a session_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
session_data = SessionCreate(
|
||||
@@ -185,10 +183,10 @@ class TestCreateSession:
|
||||
device_id="device_123",
|
||||
ip_address="192.168.1.100",
|
||||
user_agent="Mozilla/5.0",
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
location_city="San Francisco",
|
||||
location_country="USA"
|
||||
location_country="USA",
|
||||
)
|
||||
result = await session_crud.create_session(session, obj_in=session_data)
|
||||
|
||||
@@ -204,7 +202,7 @@ class TestDeactivate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_success(self, async_test_db, async_test_user):
|
||||
"""Test successfully deactivating a session_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -214,8 +212,8 @@ class TestDeactivate:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -229,7 +227,7 @@ class TestDeactivate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_not_found(self, async_test_db):
|
||||
"""Test deactivating non-existent session returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session_crud.deactivate(session, session_id=str(uuid4()))
|
||||
@@ -240,9 +238,11 @@ class TestDeactivateAllUserSessions:
|
||||
"""Tests for deactivate_all_user_sessions method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_user_sessions_success(self, async_test_db, async_test_user):
|
||||
async def test_deactivate_all_user_sessions_success(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test deactivating all user sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Create minimal sessions for test (2 instead of 5)
|
||||
@@ -254,16 +254,15 @@ class TestDeactivateAllUserSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.deactivate_all_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 2
|
||||
|
||||
@@ -274,7 +273,7 @@ class TestUpdateLastUsed:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_last_used_success(self, async_test_db, async_test_user):
|
||||
"""Test updating last_used_at timestamp."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -284,8 +283,8 @@ class TestUpdateLastUsed:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -303,7 +302,7 @@ class TestGetUserSessionCount:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user session count."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
@@ -314,28 +313,26 @@ class TestGetUserSessionCount:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(sess)
|
||||
await session.commit()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.get_user_session_count(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_empty(self, async_test_db):
|
||||
"""Test getting session count for user with no sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.get_user_session_count(
|
||||
session,
|
||||
user_id=str(uuid4())
|
||||
session, user_id=str(uuid4())
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@@ -346,7 +343,7 @@ class TestUpdateRefreshToken:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token_success(self, async_test_db, async_test_user):
|
||||
"""Test updating refresh token JTI and expiration."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -356,26 +353,34 @@ class TestUpdateRefreshToken:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
await session.refresh(user_session)
|
||||
|
||||
new_jti = "new_jti_123"
|
||||
new_expires = datetime.now(timezone.utc) + timedelta(days=14)
|
||||
new_expires = datetime.now(UTC) + timedelta(days=14)
|
||||
|
||||
result = await session_crud.update_refresh_token(
|
||||
session,
|
||||
session=user_session,
|
||||
new_jti=new_jti,
|
||||
new_expires_at=new_expires
|
||||
new_expires_at=new_expires,
|
||||
)
|
||||
|
||||
assert result.refresh_token_jti == new_jti
|
||||
# Compare timestamps ignoring timezone info
|
||||
assert abs((result.expires_at.replace(tzinfo=None) - new_expires.replace(tzinfo=None)).total_seconds()) < 1
|
||||
assert (
|
||||
abs(
|
||||
(
|
||||
result.expires_at.replace(tzinfo=None)
|
||||
- new_expires.replace(tzinfo=None)
|
||||
).total_seconds()
|
||||
)
|
||||
< 1
|
||||
)
|
||||
|
||||
|
||||
class TestCleanupExpired:
|
||||
@@ -384,7 +389,7 @@ class TestCleanupExpired:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_success(self, async_test_db, async_test_user):
|
||||
"""Test cleaning up old expired inactive sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create old expired inactive session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -395,9 +400,9 @@ class TestCleanupExpired:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=35)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=5),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=35),
|
||||
created_at=datetime.now(UTC) - timedelta(days=35),
|
||||
)
|
||||
session.add(old_session)
|
||||
await session.commit()
|
||||
@@ -410,7 +415,7 @@ class TestCleanupExpired:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_keeps_recent(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup keeps recent expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create recent expired inactive session (less than keep_days old)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -421,9 +426,9 @@ class TestCleanupExpired:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
session.add(recent_session)
|
||||
await session.commit()
|
||||
@@ -436,7 +441,7 @@ class TestCleanupExpired:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_keeps_active(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup does not delete active sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create old expired but ACTIVE session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -447,9 +452,9 @@ class TestCleanupExpired:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True, # Active
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=35),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=35)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=5),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=35),
|
||||
created_at=datetime.now(UTC) - timedelta(days=35),
|
||||
)
|
||||
session.add(active_session)
|
||||
await session.commit()
|
||||
@@ -464,9 +469,11 @@ class TestCleanupExpiredForUser:
|
||||
"""Tests for cleanup_expired_for_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_success(self, async_test_db, async_test_user):
|
||||
async def test_cleanup_expired_for_user_success(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test cleaning up expired sessions for specific user."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create expired inactive session for user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -477,8 +484,8 @@ class TestCleanupExpiredForUser:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
session.add(expired_session)
|
||||
await session.commit()
|
||||
@@ -486,27 +493,27 @@ class TestCleanupExpiredForUser:
|
||||
# Cleanup for user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_invalid_uuid(self, async_test_db):
|
||||
"""Test cleanup with invalid user UUID."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Invalid user ID format"):
|
||||
await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id="not-a-valid-uuid"
|
||||
session, user_id="not-a-valid-uuid"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_keeps_active(self, async_test_db, async_test_user):
|
||||
async def test_cleanup_expired_for_user_keeps_active(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test that cleanup for user keeps active sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create expired but active session
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -517,8 +524,8 @@ class TestCleanupExpiredForUser:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True, # Active
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(days=2)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC) - timedelta(days=2),
|
||||
)
|
||||
session.add(active_session)
|
||||
await session.commit()
|
||||
@@ -526,8 +533,7 @@ class TestCleanupExpiredForUser:
|
||||
# Cleanup
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
assert count == 0 # Should not delete active sessions
|
||||
|
||||
@@ -536,9 +542,11 @@ class TestGetUserSessionsWithUser:
|
||||
"""Tests for get_user_sessions with eager loading."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_with_user_relationship(self, async_test_db, async_test_user):
|
||||
async def test_get_user_sessions_with_user_relationship(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test getting sessions with user relationship loaded."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -548,8 +556,8 @@ class TestGetUserSessionsWithUser:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -557,8 +565,6 @@ class TestGetUserSessionsWithUser:
|
||||
# Get with user relationship
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results = await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id),
|
||||
with_user=True
|
||||
session, user_id=str(async_test_user.id), with_user=True
|
||||
)
|
||||
assert len(results) >= 1
|
||||
|
||||
@@ -2,12 +2,14 @@
|
||||
"""
|
||||
Comprehensive tests for session CRUD database failure scenarios.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from sqlalchemy.exc import OperationalError, IntegrityError
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate
|
||||
@@ -19,13 +21,14 @@ class TestSessionCRUDGetByJtiFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_jti_database_error(self, async_test_db):
|
||||
"""Test get_by_jti handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("DB connection lost", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_by_jti(session, jti="test_jti")
|
||||
|
||||
@@ -36,13 +39,14 @@ class TestSessionCRUDGetActiveByJtiFailures:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_by_jti_database_error(self, async_test_db):
|
||||
"""Test get_active_by_jti handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Query timeout", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_active_by_jti(session, jti="test_jti")
|
||||
|
||||
@@ -51,19 +55,21 @@ class TestSessionCRUDGetUserSessionsFailures:
|
||||
"""Test get_user_sessions exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_sessions_database_error(self, async_test_db, async_test_user):
|
||||
async def test_get_user_sessions_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_user_sessions handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Database error", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
|
||||
@@ -71,24 +77,29 @@ class TestSessionCRUDCreateSessionFailures:
|
||||
"""Test create_session exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_create_session_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test create_session handles commit failures with rollback."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Commit failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to create session"):
|
||||
@@ -97,24 +108,29 @@ class TestSessionCRUDCreateSessionFailures:
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_unexpected_error_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_create_session_unexpected_error_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test create_session handles unexpected errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise RuntimeError("Unexpected error")
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
session_data = SessionCreate(
|
||||
user_id=async_test_user.id,
|
||||
refresh_token_jti=str(uuid4()),
|
||||
device_name="Test Device",
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to create session"):
|
||||
@@ -127,9 +143,11 @@ class TestSessionCRUDDeactivateFailures:
|
||||
"""Test deactivate exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_deactivate_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test deactivate handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session first
|
||||
async with SessionLocal() as session:
|
||||
@@ -140,8 +158,8 @@ class TestSessionCRUDDeactivateFailures:
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -150,13 +168,18 @@ class TestSessionCRUDDeactivateFailures:
|
||||
|
||||
# Test deactivate failure
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Deactivate failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.deactivate(session, session_id=str(session_id))
|
||||
await session_crud.deactivate(
|
||||
session, session_id=str(session_id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@@ -165,20 +188,24 @@ class TestSessionCRUDDeactivateAllFailures:
|
||||
"""Test deactivate_all_user_sessions exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_all_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_deactivate_all_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test deactivate_all handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Bulk deactivate failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.deactivate_all_user_sessions(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
@@ -188,9 +215,11 @@ class TestSessionCRUDUpdateLastUsedFailures:
|
||||
"""Test update_last_used exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_last_used_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_update_last_used_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test update_last_used handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session
|
||||
async with SessionLocal() as session:
|
||||
@@ -201,8 +230,8 @@ class TestSessionCRUDUpdateLastUsedFailures:
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -211,15 +240,19 @@ class TestSessionCRUDUpdateLastUsedFailures:
|
||||
# Test update failure
|
||||
async with SessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user_session import UserSession as US
|
||||
|
||||
result = await session.execute(select(US).where(US.id == user_session.id))
|
||||
sess = result.scalar_one()
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Update failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.update_last_used(session, session=sess)
|
||||
|
||||
@@ -230,9 +263,11 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
|
||||
"""Test update_refresh_token exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_refresh_token_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_update_refresh_token_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test update_refresh_token handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session
|
||||
async with SessionLocal() as session:
|
||||
@@ -243,8 +278,8 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
|
||||
ip_address="127.0.0.1",
|
||||
user_agent="Test Agent",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(user_session)
|
||||
await session.commit()
|
||||
@@ -253,21 +288,25 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
|
||||
# Test update failure
|
||||
async with SessionLocal() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user_session import UserSession as US
|
||||
|
||||
result = await session.execute(select(US).where(US.id == user_session.id))
|
||||
sess = result.scalar_one()
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Token update failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.update_refresh_token(
|
||||
session,
|
||||
session=sess,
|
||||
new_jti=str(uuid4()),
|
||||
new_expires_at=datetime.now(timezone.utc) + timedelta(days=14)
|
||||
new_expires_at=datetime.now(UTC) + timedelta(days=14),
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
@@ -277,16 +316,21 @@ class TestSessionCRUDCleanupExpiredFailures:
|
||||
"""Test cleanup_expired exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_commit_failure_triggers_rollback(self, async_test_db):
|
||||
async def test_cleanup_expired_commit_failure_triggers_rollback(
|
||||
self, async_test_db
|
||||
):
|
||||
"""Test cleanup_expired handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Cleanup failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.cleanup_expired(session, keep_days=30)
|
||||
|
||||
@@ -297,20 +341,24 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
|
||||
"""Test cleanup_expired_for_user exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(self, async_test_db, async_test_user):
|
||||
async def test_cleanup_expired_for_user_commit_failure_triggers_rollback(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test cleanup_expired_for_user handles commit failures."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("User cleanup failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with patch.object(session, "commit", side_effect=mock_commit):
|
||||
with patch.object(
|
||||
session, "rollback", new_callable=AsyncMock
|
||||
) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.cleanup_expired_for_user(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
mock_rollback.assert_called_once()
|
||||
@@ -320,17 +368,19 @@ class TestSessionCRUDGetUserSessionCountFailures:
|
||||
"""Test get_user_session_count exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_session_count_database_error(self, async_test_db, async_test_user):
|
||||
async def test_get_user_session_count_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test get_user_session_count handles database errors."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
|
||||
async def mock_execute(*args, **kwargs):
|
||||
raise OperationalError("Count query failed", {}, Exception())
|
||||
|
||||
with patch.object(session, 'execute', side_effect=mock_execute):
|
||||
with patch.object(session, "execute", side_effect=mock_execute):
|
||||
with pytest.raises(OperationalError):
|
||||
await session_crud.get_user_session_count(
|
||||
session,
|
||||
user_id=str(async_test_user.id)
|
||||
session, user_id=str(async_test_user.id)
|
||||
)
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
"""
|
||||
Comprehensive tests for async user CRUD operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
|
||||
@@ -17,7 +15,7 @@ class TestGetByEmail:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_success(self, async_test_db, async_test_user):
|
||||
"""Test getting user by email."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_crud.get_by_email(session, email=async_test_user.email)
|
||||
@@ -28,10 +26,12 @@ class TestGetByEmail:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_email_not_found(self, async_test_db):
|
||||
"""Test getting non-existent email returns None."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await user_crud.get_by_email(session, email="nonexistent@example.com")
|
||||
result = await user_crud.get_by_email(
|
||||
session, email="nonexistent@example.com"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class TestCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, async_test_db):
|
||||
"""Test successfully creating a user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
@@ -49,7 +49,7 @@ class TestCreate:
|
||||
password="SecurePass123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="+1234567890"
|
||||
phone_number="+1234567890",
|
||||
)
|
||||
result = await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@@ -65,7 +65,7 @@ class TestCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_superuser_success(self, async_test_db):
|
||||
"""Test creating a superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
@@ -73,7 +73,7 @@ class TestCreate:
|
||||
password="SuperPass123!",
|
||||
first_name="Super",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
is_superuser=True,
|
||||
)
|
||||
result = await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
@@ -83,14 +83,14 @@ class TestCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_email_fails(self, async_test_db, async_test_user):
|
||||
"""Test creating user with duplicate email raises ValueError."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Duplicate email
|
||||
password="AnotherPass123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -105,16 +105,14 @@ class TestUpdate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_basic_fields(self, async_test_db, async_test_user):
|
||||
"""Test updating basic user fields."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get fresh copy of user
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
|
||||
update_data = UserUpdate(
|
||||
first_name="Updated",
|
||||
last_name="Name",
|
||||
phone_number="+9876543210"
|
||||
first_name="Updated", last_name="Name", phone_number="+9876543210"
|
||||
)
|
||||
result = await user_crud.update(session, db_obj=user, obj_in=update_data)
|
||||
|
||||
@@ -125,7 +123,7 @@ class TestUpdate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_password(self, async_test_db):
|
||||
"""Test updating user password."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a fresh user for this test
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -133,7 +131,7 @@ class TestUpdate:
|
||||
email="passwordtest@example.com",
|
||||
password="OldPassword123!",
|
||||
first_name="Pass",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -149,12 +147,14 @@ class TestUpdate:
|
||||
await session.refresh(result)
|
||||
assert result.password_hash != old_password_hash
|
||||
assert result.password_hash is not None
|
||||
assert "NewDifferentPassword123!" not in result.password_hash # Should be hashed
|
||||
assert (
|
||||
"NewDifferentPassword123!" not in result.password_hash
|
||||
) # Should be hashed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_with_dict(self, async_test_db, async_test_user):
|
||||
"""Test updating user with dictionary."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -171,13 +171,11 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_basic(self, async_test_db, async_test_user):
|
||||
"""Test basic pagination."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10
|
||||
session, skip=0, limit=10
|
||||
)
|
||||
assert total >= 1
|
||||
assert len(users) >= 1
|
||||
@@ -186,7 +184,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_asc(self, async_test_db):
|
||||
"""Test sorting in ascending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -195,17 +193,13 @@ class TestGetMultiWithTotal:
|
||||
email=f"sort{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="asc"
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="email", sort_order="asc"
|
||||
)
|
||||
|
||||
# Check if sorted (at least the test users)
|
||||
@@ -216,7 +210,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_sorting_desc(self, async_test_db):
|
||||
"""Test sorting in descending order."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -225,17 +219,13 @@ class TestGetMultiWithTotal:
|
||||
email=f"desc{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"User{i}",
|
||||
last_name="Test"
|
||||
last_name="Test",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
sort_by="email",
|
||||
sort_order="desc"
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="email", sort_order="desc"
|
||||
)
|
||||
|
||||
# Check if sorted descending (at least the test users)
|
||||
@@ -246,7 +236,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_filtering(self, async_test_db):
|
||||
"""Test filtering by field."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active and inactive users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -254,7 +244,7 @@ class TestGetMultiWithTotal:
|
||||
email="active@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Active",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=active_user)
|
||||
|
||||
@@ -262,23 +252,18 @@ class TestGetMultiWithTotal:
|
||||
email="inactive@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
created_inactive = await user_crud.create(session, obj_in=inactive_user)
|
||||
|
||||
# Deactivate the user
|
||||
await user_crud.update(
|
||||
session,
|
||||
db_obj=created_inactive,
|
||||
obj_in={"is_active": False}
|
||||
session, db_obj=created_inactive, obj_in={"is_active": False}
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
filters={"is_active": True}
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
session, skip=0, limit=100, filters={"is_active": True}
|
||||
)
|
||||
|
||||
# All returned users should be active
|
||||
@@ -287,7 +272,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_search(self, async_test_db):
|
||||
"""Test search functionality."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user with unique name
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -295,16 +280,13 @@ class TestGetMultiWithTotal:
|
||||
email="searchable@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Searchable",
|
||||
last_name="UserName"
|
||||
last_name="UserName",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=100,
|
||||
search="Searchable"
|
||||
session, skip=0, limit=100, search="Searchable"
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
@@ -313,7 +295,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_pagination(self, async_test_db):
|
||||
"""Test pagination with skip and limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -322,23 +304,19 @@ class TestGetMultiWithTotal:
|
||||
email=f"page{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Page{i}",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Get first page
|
||||
users_page1, total = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=0,
|
||||
limit=2
|
||||
session, skip=0, limit=2
|
||||
)
|
||||
|
||||
# Get second page
|
||||
users_page2, total2 = await user_crud.get_multi_with_total(
|
||||
session,
|
||||
skip=2,
|
||||
limit=2
|
||||
session, skip=2, limit=2
|
||||
)
|
||||
|
||||
# Total should be same
|
||||
@@ -349,7 +327,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_skip(self, async_test_db):
|
||||
"""Test validation fails for negative skip."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -360,7 +338,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_negative_limit(self, async_test_db):
|
||||
"""Test validation fails for negative limit."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -371,7 +349,7 @@ class TestGetMultiWithTotal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_validation_max_limit(self, async_test_db):
|
||||
"""Test validation fails for limit > 1000."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -386,7 +364,7 @@ class TestBulkUpdateStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_success(self, async_test_db):
|
||||
"""Test bulk updating user status."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
@@ -396,7 +374,7 @@ class TestBulkUpdateStatus:
|
||||
email=f"bulk{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Bulk{i}",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
@@ -404,9 +382,7 @@ class TestBulkUpdateStatus:
|
||||
# Bulk deactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
is_active=False
|
||||
session, user_ids=user_ids, is_active=False
|
||||
)
|
||||
assert count == 3
|
||||
|
||||
@@ -419,20 +395,18 @@ class TestBulkUpdateStatus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_empty_list(self, async_test_db):
|
||||
"""Test bulk update with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=[],
|
||||
is_active=False
|
||||
session, user_ids=[], is_active=False
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_reactivate(self, async_test_db):
|
||||
"""Test bulk reactivating users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create inactive user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -440,7 +414,7 @@ class TestBulkUpdateStatus:
|
||||
email="reactivate@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Reactivate",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
# Deactivate
|
||||
@@ -450,9 +424,7 @@ class TestBulkUpdateStatus:
|
||||
# Reactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
is_active=True
|
||||
session, user_ids=[user_id], is_active=True
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
@@ -468,7 +440,7 @@ class TestBulkSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_success(self, async_test_db):
|
||||
"""Test bulk soft deleting users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
@@ -478,17 +450,14 @@ class TestBulkSoftDelete:
|
||||
email=f"delete{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Delete{i}",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
|
||||
# Bulk delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids
|
||||
)
|
||||
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
|
||||
assert count == 3
|
||||
|
||||
# Verify all are soft deleted
|
||||
@@ -501,7 +470,7 @@ class TestBulkSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
|
||||
"""Test bulk soft delete with excluded user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple users
|
||||
user_ids = []
|
||||
@@ -511,7 +480,7 @@ class TestBulkSoftDelete:
|
||||
email=f"exclude{i}@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name=f"Exclude{i}",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_ids.append(user.id)
|
||||
@@ -520,9 +489,7 @@ class TestBulkSoftDelete:
|
||||
exclude_id = user_ids[0]
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=user_ids,
|
||||
exclude_user_id=exclude_id
|
||||
session, user_ids=user_ids, exclude_user_id=exclude_id
|
||||
)
|
||||
assert count == 2 # Only 2 deleted
|
||||
|
||||
@@ -534,19 +501,16 @@ class TestBulkSoftDelete:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_empty_list(self, async_test_db):
|
||||
"""Test bulk delete with empty list returns 0."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[]
|
||||
)
|
||||
count = await user_crud.bulk_soft_delete(session, user_ids=[])
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_all_excluded(self, async_test_db):
|
||||
"""Test bulk delete where all users are excluded."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -554,7 +518,7 @@ class TestBulkSoftDelete:
|
||||
email="onlyuser@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Only",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -562,16 +526,14 @@ class TestBulkSoftDelete:
|
||||
# Try to delete but exclude
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id],
|
||||
exclude_user_id=user_id
|
||||
session, user_ids=[user_id], exclude_user_id=user_id
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_already_deleted(self, async_test_db):
|
||||
"""Test bulk delete doesn't re-delete already deleted users."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create and delete user
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -579,7 +541,7 @@ class TestBulkSoftDelete:
|
||||
email="predeleted@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="PreDeleted",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
@@ -589,10 +551,7 @@ class TestBulkSoftDelete:
|
||||
|
||||
# Try to delete again
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[user_id]
|
||||
)
|
||||
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
|
||||
assert count == 0 # Already deleted
|
||||
|
||||
|
||||
@@ -602,7 +561,7 @@ class TestUtilityMethods:
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_true(self, async_test_db, async_test_user):
|
||||
"""Test is_active returns True for active user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -611,14 +570,14 @@ class TestUtilityMethods:
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_active_false(self, async_test_db):
|
||||
"""Test is_active returns False for inactive user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user_data = UserCreate(
|
||||
email="inactive2@example.com",
|
||||
password="SecurePass123!",
|
||||
first_name="Inactive",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
|
||||
@@ -628,7 +587,7 @@ class TestUtilityMethods:
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_true(self, async_test_db, async_test_superuser):
|
||||
"""Test is_superuser returns True for superuser."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_superuser.id))
|
||||
@@ -637,7 +596,7 @@ class TestUtilityMethods:
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_false(self, async_test_db, async_test_user):
|
||||
"""Test is_superuser returns False for regular user_crud."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
@@ -654,42 +613,52 @@ class TestUserExceptionHandlers:
|
||||
async def test_get_by_email_database_error(self, async_test_db):
|
||||
"""Test get_by_email handles database errors (covers lines 30-32)."""
|
||||
from unittest.mock import patch
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with patch.object(session, 'execute', side_effect=Exception("Database query failed")):
|
||||
with patch.object(
|
||||
session, "execute", side_effect=Exception("Database query failed")
|
||||
):
|
||||
with pytest.raises(Exception, match="Database query failed"):
|
||||
await user_crud.get_by_email(session, email="test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_database_error(self, async_test_db, async_test_user):
|
||||
async def test_bulk_update_status_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test bulk_update_status handles database errors (covers lines 205-208)."""
|
||||
from unittest.mock import patch, AsyncMock
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock execute to fail
|
||||
with patch.object(session, 'execute', side_effect=Exception("Bulk update failed")):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock):
|
||||
with patch.object(
|
||||
session, "execute", side_effect=Exception("Bulk update failed")
|
||||
):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
with pytest.raises(Exception, match="Bulk update failed"):
|
||||
await user_crud.bulk_update_status(
|
||||
session,
|
||||
user_ids=[async_test_user.id],
|
||||
is_active=False
|
||||
session, user_ids=[async_test_user.id], is_active=False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_soft_delete_database_error(self, async_test_db, async_test_user):
|
||||
async def test_bulk_soft_delete_database_error(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test bulk_soft_delete handles database errors (covers lines 257-260)."""
|
||||
from unittest.mock import patch, AsyncMock
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# Mock execute to fail
|
||||
with patch.object(session, 'execute', side_effect=Exception("Bulk delete failed")):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock):
|
||||
with patch.object(
|
||||
session, "execute", side_effect=Exception("Bulk delete failed")
|
||||
):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
with pytest.raises(Exception, match="Bulk delete failed"):
|
||||
await user_crud.bulk_soft_delete(
|
||||
session,
|
||||
user_ids=[async_test_user.id]
|
||||
session, user_ids=[async_test_user.id]
|
||||
)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# tests/models/test_user.py
|
||||
import uuid
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@@ -166,7 +168,6 @@ def test_user_required_fields(db_session):
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
|
||||
def test_user_defaults(db_session):
|
||||
"""Test that default values are correctly set."""
|
||||
# Arrange - Create a minimal user with only required fields
|
||||
@@ -210,22 +211,13 @@ def test_user_with_complex_json_preferences(db_session):
|
||||
"""Test storing and retrieving complex JSON preferences."""
|
||||
# Arrange - Create a user with nested JSON preferences
|
||||
complex_preferences = {
|
||||
"theme": {
|
||||
"mode": "dark",
|
||||
"colors": {
|
||||
"primary": "#333",
|
||||
"secondary": "#666"
|
||||
}
|
||||
},
|
||||
"theme": {"mode": "dark", "colors": {"primary": "#333", "secondary": "#666"}},
|
||||
"notifications": {
|
||||
"email": True,
|
||||
"sms": False,
|
||||
"push": {
|
||||
"enabled": True,
|
||||
"quiet_hours": [22, 7]
|
||||
}
|
||||
"push": {"enabled": True, "quiet_hours": [22, 7]},
|
||||
},
|
||||
"tags": ["important", "family", "events"]
|
||||
"tags": ["important", "family", "events"],
|
||||
}
|
||||
|
||||
user = User(
|
||||
@@ -234,16 +226,18 @@ def test_user_with_complex_json_preferences(db_session):
|
||||
password_hash="hashedpassword",
|
||||
first_name="Complex",
|
||||
last_name="JSON",
|
||||
preferences=complex_preferences
|
||||
preferences=complex_preferences,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
# Act - Retrieve the user
|
||||
retrieved_user = db_session.query(User).filter_by(email="complex@example.com").first()
|
||||
retrieved_user = (
|
||||
db_session.query(User).filter_by(email="complex@example.com").first()
|
||||
)
|
||||
|
||||
# Assert - The complex JSON should be preserved
|
||||
assert retrieved_user.preferences == complex_preferences
|
||||
assert retrieved_user.preferences["theme"]["colors"]["primary"] == "#333"
|
||||
assert retrieved_user.preferences["notifications"]["push"]["quiet_hours"] == [22, 7]
|
||||
assert "important" in retrieved_user.preferences["tags"]
|
||||
assert "important" in retrieved_user.preferences["tags"]
|
||||
|
||||
@@ -5,6 +5,7 @@ Covers Pydantic validators for:
|
||||
- Slug validation (lines 26, 28, 30, 32, 62-70)
|
||||
- Name validation (lines 40, 77)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
@@ -20,19 +21,13 @@ class TestOrganizationBaseValidators:
|
||||
|
||||
def test_valid_organization_base(self):
|
||||
"""Test that valid data passes validation."""
|
||||
org = OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = OrganizationBase(name="Test Organization", slug="test-org")
|
||||
assert org.name == "Test Organization"
|
||||
assert org.slug == "test-org"
|
||||
|
||||
def test_slug_none_returns_none(self):
|
||||
"""Test that None slug is allowed (covers line 26)."""
|
||||
org = OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug=None
|
||||
)
|
||||
org = OrganizationBase(name="Test Organization", slug=None)
|
||||
assert org.slug is None
|
||||
|
||||
def test_slug_invalid_characters_rejected(self):
|
||||
@@ -40,57 +35,46 @@ class TestOrganizationBaseValidators:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="Test_Org!" # Uppercase and special chars
|
||||
slug="Test_Org!", # Uppercase and special chars
|
||||
)
|
||||
errors = exc_info.value.errors()
|
||||
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
def test_slug_starts_with_hyphen_rejected(self):
|
||||
"""Test slug starting with hyphen is rejected (covers line 30)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="-test-org"
|
||||
)
|
||||
OrganizationBase(name="Test Organization", slug="-test-org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
|
||||
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_slug_ends_with_hyphen_rejected(self):
|
||||
"""Test slug ending with hyphen is rejected (covers line 30)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="test-org-"
|
||||
)
|
||||
OrganizationBase(name="Test Organization", slug="test-org-")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
|
||||
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_slug_consecutive_hyphens_rejected(self):
|
||||
"""Test slug with consecutive hyphens is rejected (covers line 32)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name="Test Organization",
|
||||
slug="test--org"
|
||||
)
|
||||
OrganizationBase(name="Test Organization", slug="test--org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
def test_name_whitespace_only_rejected(self):
|
||||
"""Test whitespace-only name is rejected (covers line 40)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationBase(
|
||||
name=" ",
|
||||
slug="test-org"
|
||||
)
|
||||
OrganizationBase(name=" ", slug="test-org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name cannot be empty" in str(e['msg']) for e in errors)
|
||||
assert any("name cannot be empty" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_name_trimmed(self):
|
||||
"""Test that name is trimmed."""
|
||||
org = OrganizationBase(
|
||||
name=" Test Organization ",
|
||||
slug="test-org"
|
||||
)
|
||||
org = OrganizationBase(name=" Test Organization ", slug="test-org")
|
||||
assert org.name == "Test Organization"
|
||||
|
||||
|
||||
@@ -99,22 +83,18 @@ class TestOrganizationCreateValidators:
|
||||
|
||||
def test_valid_organization_create(self):
|
||||
"""Test that valid data passes validation."""
|
||||
org = OrganizationCreate(
|
||||
name="Test Organization",
|
||||
slug="test-org"
|
||||
)
|
||||
org = OrganizationCreate(name="Test Organization", slug="test-org")
|
||||
assert org.name == "Test Organization"
|
||||
assert org.slug == "test-org"
|
||||
|
||||
def test_slug_validation_inherited(self):
|
||||
"""Test that slug validation is inherited from base."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationCreate(
|
||||
name="Test",
|
||||
slug="Invalid_Slug!"
|
||||
)
|
||||
OrganizationCreate(name="Test", slug="Invalid_Slug!")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
|
||||
class TestOrganizationUpdateValidators:
|
||||
@@ -122,10 +102,7 @@ class TestOrganizationUpdateValidators:
|
||||
|
||||
def test_valid_organization_update(self):
|
||||
"""Test that valid update data passes validation."""
|
||||
org = OrganizationUpdate(
|
||||
name="Updated Name",
|
||||
slug="updated-slug"
|
||||
)
|
||||
org = OrganizationUpdate(name="Updated Name", slug="updated-slug")
|
||||
assert org.name == "Updated Name"
|
||||
assert org.slug == "updated-slug"
|
||||
|
||||
@@ -139,35 +116,39 @@ class TestOrganizationUpdateValidators:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(slug="Test_Org!")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("lowercase letters, numbers, and hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"lowercase letters, numbers, and hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
def test_update_slug_starts_with_hyphen_rejected(self):
|
||||
"""Test update slug starting with hyphen is rejected (covers line 66)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(slug="-test-org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
|
||||
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_update_slug_ends_with_hyphen_rejected(self):
|
||||
"""Test update slug ending with hyphen is rejected (covers line 66)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(slug="test-org-")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot start or end with a hyphen" in str(e['msg']) for e in errors)
|
||||
assert any("cannot start or end with a hyphen" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_update_slug_consecutive_hyphens_rejected(self):
|
||||
"""Test update slug with consecutive hyphens is rejected (covers line 68)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(slug="test--org")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cannot contain consecutive hyphens" in str(e['msg']) for e in errors)
|
||||
assert any(
|
||||
"cannot contain consecutive hyphens" in str(e["msg"]) for e in errors
|
||||
)
|
||||
|
||||
def test_update_name_whitespace_only_rejected(self):
|
||||
"""Test whitespace-only name in update is rejected (covers line 77)."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
OrganizationUpdate(name=" ")
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name cannot be empty" in str(e['msg']) for e in errors)
|
||||
assert any("name cannot be empty" in str(e["msg"]) for e in errors)
|
||||
|
||||
def test_update_name_none_allowed(self):
|
||||
"""Test that None name is allowed in update."""
|
||||
|
||||
@@ -1,80 +1,177 @@
|
||||
# tests/schemas/test_user_schemas.py
|
||||
import pytest
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.users import UserBase, UserCreate
|
||||
|
||||
|
||||
class TestPhoneNumberValidation:
|
||||
"""Tests for phone number validation in user schemas"""
|
||||
|
||||
def test_valid_swiss_numbers(self):
|
||||
"""Test valid Swiss phone numbers are accepted"""
|
||||
# International format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41791234567")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+41791234567",
|
||||
)
|
||||
assert user.phone_number == "+41791234567"
|
||||
|
||||
# Local format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0791234567")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0791234567",
|
||||
)
|
||||
assert user.phone_number == "0791234567"
|
||||
|
||||
# With formatting characters
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 79 123 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+41 79 123 45 67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 123 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="079 123 45 67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41-79-123-45-67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+41-79-123-45-67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079-123-45-67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="079-123-45-67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+41 (79) 123 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+41791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+41 (79) 123 45 67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+41791234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="079 (123) 45 67")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "0791234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="079 (123) 45 67",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "0791234567"
|
||||
|
||||
def test_valid_italian_numbers(self):
|
||||
"""Test valid Italian phone numbers are accepted"""
|
||||
# International format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+393451234567")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+393451234567",
|
||||
)
|
||||
assert user.phone_number == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39345123456")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+39345123456",
|
||||
)
|
||||
assert user.phone_number == "+39345123456"
|
||||
|
||||
# Local format
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="03451234567")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="03451234567",
|
||||
)
|
||||
assert user.phone_number == "03451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345123456789")
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0345123456789",
|
||||
)
|
||||
assert user.phone_number == "0345123456789"
|
||||
|
||||
# With formatting characters
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 345 123 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+39 345 123 4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 123 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0345 123 4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39-345-123-4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+39-345-123-4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345-123-4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0345-123-4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="+39 (345) 123 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "+393451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="+39 (345) 123 4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "+393451234567"
|
||||
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number="0345 (123) 4567")
|
||||
assert re.sub(r'[\s\-\(\)]', '', user.phone_number) == "03451234567"
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number="0345 (123) 4567",
|
||||
)
|
||||
assert re.sub(r"[\s\-\(\)]", "", user.phone_number) == "03451234567"
|
||||
|
||||
def test_none_phone_number(self):
|
||||
"""Test that None is accepted as a valid value (optional phone number)"""
|
||||
user = UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=None)
|
||||
user = UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number=None,
|
||||
)
|
||||
assert user.phone_number is None
|
||||
|
||||
def test_invalid_phone_numbers(self):
|
||||
@@ -83,17 +180,14 @@ class TestPhoneNumberValidation:
|
||||
# Too short
|
||||
"+12",
|
||||
"012",
|
||||
|
||||
# Invalid characters
|
||||
"+41xyz123456",
|
||||
"079abc4567",
|
||||
"123-abc-7890",
|
||||
"+1(800)CALL-NOW",
|
||||
|
||||
# Completely invalid formats
|
||||
"++4412345678", # Double plus
|
||||
# Note: "()+41123456" becomes "+41123456" after cleaning, which is valid
|
||||
|
||||
# Empty string
|
||||
"",
|
||||
# Spaces only
|
||||
@@ -102,7 +196,12 @@ class TestPhoneNumberValidation:
|
||||
|
||||
for number in invalid_numbers:
|
||||
with pytest.raises(ValidationError):
|
||||
UserBase(email="test@example.com", first_name="Test", last_name="User", phone_number=number)
|
||||
UserBase(
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
phone_number=number,
|
||||
)
|
||||
|
||||
def test_phone_validation_in_user_create(self):
|
||||
"""Test that phone validation also works in UserCreate schema"""
|
||||
@@ -112,7 +211,7 @@ class TestPhoneNumberValidation:
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
password="Password123!",
|
||||
phone_number="+41791234567"
|
||||
phone_number="+41791234567",
|
||||
)
|
||||
assert user.phone_number == "+41791234567"
|
||||
|
||||
@@ -123,5 +222,5 @@ class TestPhoneNumberValidation:
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
password="Password123!",
|
||||
phone_number="invalid-number"
|
||||
)
|
||||
phone_number="invalid-number",
|
||||
)
|
||||
|
||||
@@ -7,12 +7,13 @@ Covers all edge cases in validation functions:
|
||||
- validate_email_format (line 148)
|
||||
- validate_slug (lines 170-183)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.validators import (
|
||||
validate_email_format,
|
||||
validate_password_strength,
|
||||
validate_phone_number,
|
||||
validate_email_format,
|
||||
validate_slug,
|
||||
)
|
||||
|
||||
@@ -108,12 +109,14 @@ class TestPhoneNumberValidator:
|
||||
validate_phone_number("+123456789012345") # 15 digits after +
|
||||
|
||||
def test_multiple_plus_symbols_rejected(self):
|
||||
"""Test phone number with multiple + symbols.
|
||||
r"""Test phone number with multiple + symbols.
|
||||
|
||||
Note: Line 115 is defensive code - the regex check at line 110 catches this first.
|
||||
The regex ^(?:\+[0-9]{8,14}|0[0-9]{8,14})$ only allows + at the start.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="must start with \\+ or 0 followed by 8-14 digits"):
|
||||
with pytest.raises(
|
||||
ValueError, match="must start with \\+ or 0 followed by 8-14 digits"
|
||||
):
|
||||
validate_phone_number("+1234+5678901")
|
||||
|
||||
def test_non_digit_after_prefix_rejected(self):
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
# tests/services/test_auth_service.py
|
||||
import uuid
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.auth import get_password_hash, verify_password, TokenExpiredError, TokenInvalidError
|
||||
from app.core.auth import (
|
||||
TokenInvalidError,
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, Token
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.schemas.users import Token, UserCreate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
|
||||
|
||||
class TestAuthServiceAuthentication:
|
||||
@@ -17,12 +21,14 @@ class TestAuthServiceAuthentication:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_valid_user(self, async_test_db, async_test_user):
|
||||
"""Test authenticating a user with valid credentials"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(password)
|
||||
await session.commit()
|
||||
@@ -30,9 +36,7 @@ class TestAuthServiceAuthentication:
|
||||
# Authenticate with correct credentials
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
auth_user = await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email=async_test_user.email,
|
||||
password=password
|
||||
db=session, email=async_test_user.email, password=password
|
||||
)
|
||||
|
||||
assert auth_user is not None
|
||||
@@ -42,26 +46,28 @@ class TestAuthServiceAuthentication:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_nonexistent_user(self, async_test_db):
|
||||
"""Test authenticating with an email that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email="nonexistent@example.com",
|
||||
password="password"
|
||||
db=session, email="nonexistent@example.com", password="password"
|
||||
)
|
||||
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_with_wrong_password(self, async_test_db, async_test_user):
|
||||
async def test_authenticate_with_wrong_password(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test authenticating with the wrong password"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(password)
|
||||
await session.commit()
|
||||
@@ -69,9 +75,7 @@ class TestAuthServiceAuthentication:
|
||||
# Authenticate with wrong password
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
auth_user = await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email=async_test_user.email,
|
||||
password="WrongPassword123"
|
||||
db=session, email=async_test_user.email, password="WrongPassword123"
|
||||
)
|
||||
|
||||
assert auth_user is None
|
||||
@@ -79,12 +83,14 @@ class TestAuthServiceAuthentication:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_inactive_user(self, async_test_db, async_test_user):
|
||||
"""Test authenticating an inactive user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password and make user inactive
|
||||
password = "TestPassword123!"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(password)
|
||||
user.is_active = False
|
||||
@@ -94,9 +100,7 @@ class TestAuthServiceAuthentication:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(AuthenticationError):
|
||||
await AuthService.authenticate_user(
|
||||
db=session,
|
||||
email=async_test_user.email,
|
||||
password=password
|
||||
db=session, email=async_test_user.email, password=password
|
||||
)
|
||||
|
||||
|
||||
@@ -106,14 +110,14 @@ class TestAuthServiceUserCreation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_user(self, async_test_db):
|
||||
"""Test creating a new user"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
user_data = UserCreate(
|
||||
email="newuser@example.com",
|
||||
password="TestPassword123!",
|
||||
first_name="New",
|
||||
last_name="User",
|
||||
phone_number="+1234567890"
|
||||
phone_number="+1234567890",
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -135,15 +139,17 @@ class TestAuthServiceUserCreation:
|
||||
assert user.is_superuser is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_with_existing_email(self, async_test_db, async_test_user):
|
||||
async def test_create_user_with_existing_email(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test creating a user with an email that already exists"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
user_data = UserCreate(
|
||||
email=async_test_user.email, # Use existing email
|
||||
password="TestPassword123!",
|
||||
first_name="Duplicate",
|
||||
last_name="User"
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# Should raise AuthenticationError
|
||||
@@ -169,7 +175,7 @@ class TestAuthServiceTokens:
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens(self, async_test_db, async_test_user):
|
||||
"""Test refreshing tokens with a valid refresh token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create initial tokens
|
||||
initial_tokens = AuthService.create_tokens(async_test_user)
|
||||
@@ -177,8 +183,7 @@ class TestAuthServiceTokens:
|
||||
# Refresh tokens
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
new_tokens = await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token=initial_tokens.refresh_token
|
||||
db=session, refresh_token=initial_tokens.refresh_token
|
||||
)
|
||||
|
||||
# Verify new tokens are different from old ones
|
||||
@@ -188,7 +193,7 @@ class TestAuthServiceTokens:
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_with_invalid_token(self, async_test_db):
|
||||
"""Test refreshing tokens with an invalid token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create an invalid token
|
||||
invalid_token = "invalid.token.string"
|
||||
@@ -197,14 +202,15 @@ class TestAuthServiceTokens:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(TokenInvalidError):
|
||||
await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token=invalid_token
|
||||
db=session, refresh_token=invalid_token
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_with_access_token(self, async_test_db, async_test_user):
|
||||
async def test_refresh_tokens_with_access_token(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test refreshing tokens with an access token instead of refresh token"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create tokens
|
||||
tokens = AuthService.create_tokens(async_test_user)
|
||||
@@ -213,18 +219,20 @@ class TestAuthServiceTokens:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(TokenInvalidError):
|
||||
await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token=tokens.access_token
|
||||
db=session, refresh_token=tokens.access_token
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_with_nonexistent_user(self, async_test_db):
|
||||
"""Test refreshing tokens for a user that doesn't exist in the database"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a token for a non-existent user
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
with patch('app.core.auth.decode_token'), patch('app.core.auth.get_token_data') as mock_get_data:
|
||||
with (
|
||||
patch("app.core.auth.decode_token"),
|
||||
patch("app.core.auth.get_token_data") as mock_get_data,
|
||||
):
|
||||
# Mock the token data to return a non-existent user ID
|
||||
mock_get_data.return_value.user_id = uuid.UUID(non_existent_id)
|
||||
|
||||
@@ -232,8 +240,7 @@ class TestAuthServiceTokens:
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(TokenInvalidError):
|
||||
await AuthService.refresh_tokens(
|
||||
db=session,
|
||||
refresh_token="some.refresh.token"
|
||||
db=session, refresh_token="some.refresh.token"
|
||||
)
|
||||
|
||||
|
||||
@@ -243,12 +250,14 @@ class TestAuthServicePasswordChange:
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password(self, async_test_db, async_test_user):
|
||||
"""Test changing a user's password"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
current_password = "CurrentPassword123"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(current_password)
|
||||
await session.commit()
|
||||
@@ -260,7 +269,7 @@ class TestAuthServicePasswordChange:
|
||||
db=session,
|
||||
user_id=async_test_user.id,
|
||||
current_password=current_password,
|
||||
new_password=new_password
|
||||
new_password=new_password,
|
||||
)
|
||||
|
||||
# Verify operation was successful
|
||||
@@ -268,7 +277,9 @@ class TestAuthServicePasswordChange:
|
||||
|
||||
# Verify password was changed
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
updated_user = result.scalar_one_or_none()
|
||||
|
||||
# Verify old password no longer works
|
||||
@@ -278,14 +289,18 @@ class TestAuthServicePasswordChange:
|
||||
assert verify_password(new_password, updated_user.password_hash)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_wrong_current_password(self, async_test_db, async_test_user):
|
||||
async def test_change_password_wrong_current_password(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test changing password with incorrect current password"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Set a known password for the mock user
|
||||
current_password = "CurrentPassword123"
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user.password_hash = get_password_hash(current_password)
|
||||
await session.commit()
|
||||
@@ -298,19 +313,21 @@ class TestAuthServicePasswordChange:
|
||||
db=session,
|
||||
user_id=async_test_user.id,
|
||||
current_password=wrong_password,
|
||||
new_password="NewPassword456"
|
||||
new_password="NewPassword456",
|
||||
)
|
||||
|
||||
# Verify password was not changed
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await session.execute(select(User).where(User.id == async_test_user.id))
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == async_test_user.id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
assert verify_password(current_password, user.password_hash)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_nonexistent_user(self, async_test_db):
|
||||
"""Test changing password for a user that doesn't exist"""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
non_existent_id = uuid.uuid4()
|
||||
|
||||
@@ -320,5 +337,5 @@ class TestAuthServicePasswordChange:
|
||||
db=session,
|
||||
user_id=non_existent_id,
|
||||
current_password="CurrentPassword123",
|
||||
new_password="NewPassword456"
|
||||
new_password="NewPassword456",
|
||||
)
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
"""
|
||||
Tests for email service functionality.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
|
||||
from app.services.email_service import (
|
||||
EmailService,
|
||||
ConsoleEmailBackend,
|
||||
SMTPEmailBackend
|
||||
EmailService,
|
||||
SMTPEmailBackend,
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +26,7 @@ class TestConsoleEmailBackend:
|
||||
to=["user@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>",
|
||||
text_content="Test Text"
|
||||
text_content="Test Text",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -37,7 +39,7 @@ class TestConsoleEmailBackend:
|
||||
result = await backend.send_email(
|
||||
to=["user@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>"
|
||||
html_content="<p>Test HTML</p>",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -50,7 +52,7 @@ class TestConsoleEmailBackend:
|
||||
result = await backend.send_email(
|
||||
to=["user1@example.com", "user2@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>"
|
||||
html_content="<p>Test HTML</p>",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -66,7 +68,7 @@ class TestSMTPEmailBackend:
|
||||
host="smtp.example.com",
|
||||
port=587,
|
||||
username="test@example.com",
|
||||
password="password"
|
||||
password="password",
|
||||
)
|
||||
|
||||
assert backend.host == "smtp.example.com"
|
||||
@@ -81,14 +83,14 @@ class TestSMTPEmailBackend:
|
||||
host="smtp.example.com",
|
||||
port=587,
|
||||
username="test@example.com",
|
||||
password="password"
|
||||
password="password",
|
||||
)
|
||||
|
||||
# Should fall back to console backend since SMTP is not implemented
|
||||
result = await backend.send_email(
|
||||
to=["user@example.com"],
|
||||
subject="Test Subject",
|
||||
html_content="<p>Test HTML</p>"
|
||||
html_content="<p>Test HTML</p>",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -114,9 +116,7 @@ class TestEmailService:
|
||||
service = EmailService()
|
||||
|
||||
result = await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="test_token_123",
|
||||
user_name="John"
|
||||
to_email="user@example.com", reset_token="test_token_123", user_name="John"
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -127,8 +127,7 @@ class TestEmailService:
|
||||
service = EmailService()
|
||||
|
||||
result = await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="test_token_123"
|
||||
to_email="user@example.com", reset_token="test_token_123"
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -142,8 +141,7 @@ class TestEmailService:
|
||||
|
||||
token = "test_reset_token_xyz"
|
||||
await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token=token
|
||||
to_email="user@example.com", reset_token=token
|
||||
)
|
||||
|
||||
# Verify send_email was called
|
||||
@@ -151,7 +149,7 @@ class TestEmailService:
|
||||
call_args = backend_mock.send_email.call_args
|
||||
|
||||
# Check that token is in the HTML content
|
||||
html_content = call_args.kwargs['html_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
assert token in html_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -162,8 +160,7 @@ class TestEmailService:
|
||||
service = EmailService(backend=backend_mock)
|
||||
|
||||
result = await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="test_token"
|
||||
to_email="user@example.com", reset_token="test_token"
|
||||
)
|
||||
|
||||
assert result is False
|
||||
@@ -176,7 +173,7 @@ class TestEmailService:
|
||||
result = await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="verification_token_123",
|
||||
user_name="Jane"
|
||||
user_name="Jane",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -187,8 +184,7 @@ class TestEmailService:
|
||||
service = EmailService()
|
||||
|
||||
result = await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="verification_token_123"
|
||||
to_email="user@example.com", verification_token="verification_token_123"
|
||||
)
|
||||
|
||||
assert result is True
|
||||
@@ -202,8 +198,7 @@ class TestEmailService:
|
||||
|
||||
token = "test_verification_token_xyz"
|
||||
await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token=token
|
||||
to_email="user@example.com", verification_token=token
|
||||
)
|
||||
|
||||
# Verify send_email was called
|
||||
@@ -211,7 +206,7 @@ class TestEmailService:
|
||||
call_args = backend_mock.send_email.call_args
|
||||
|
||||
# Check that token is in the HTML content
|
||||
html_content = call_args.kwargs['html_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
assert token in html_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -222,8 +217,7 @@ class TestEmailService:
|
||||
service = EmailService(backend=backend_mock)
|
||||
|
||||
result = await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="test_token"
|
||||
to_email="user@example.com", verification_token="test_token"
|
||||
)
|
||||
|
||||
assert result is False
|
||||
@@ -236,14 +230,12 @@ class TestEmailService:
|
||||
service = EmailService(backend=backend_mock)
|
||||
|
||||
await service.send_password_reset_email(
|
||||
to_email="user@example.com",
|
||||
reset_token="token123",
|
||||
user_name="Test User"
|
||||
to_email="user@example.com", reset_token="token123", user_name="Test User"
|
||||
)
|
||||
|
||||
call_args = backend_mock.send_email.call_args
|
||||
html_content = call_args.kwargs['html_content']
|
||||
text_content = call_args.kwargs['text_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
text_content = call_args.kwargs["text_content"]
|
||||
|
||||
# Check HTML content
|
||||
assert "Password Reset" in html_content
|
||||
@@ -251,7 +243,9 @@ class TestEmailService:
|
||||
assert "Test User" in html_content
|
||||
|
||||
# Check text content
|
||||
assert "Password Reset" in text_content or "password reset" in text_content.lower()
|
||||
assert (
|
||||
"Password Reset" in text_content or "password reset" in text_content.lower()
|
||||
)
|
||||
assert "token123" in text_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -264,12 +258,12 @@ class TestEmailService:
|
||||
await service.send_email_verification(
|
||||
to_email="user@example.com",
|
||||
verification_token="verify123",
|
||||
user_name="Test User"
|
||||
user_name="Test User",
|
||||
)
|
||||
|
||||
call_args = backend_mock.send_email.call_args
|
||||
html_content = call_args.kwargs['html_content']
|
||||
text_content = call_args.kwargs['text_content']
|
||||
html_content = call_args.kwargs["html_content"]
|
||||
text_content = call_args.kwargs["text_content"]
|
||||
|
||||
# Check HTML content
|
||||
assert "Verify" in html_content
|
||||
|
||||
@@ -2,23 +2,27 @@
|
||||
"""
|
||||
Comprehensive tests for session cleanup service.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
class TestCleanupExpiredSessions:
|
||||
"""Tests for cleanup_expired_sessions function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_sessions_success(self, async_test_db, async_test_user):
|
||||
async def test_cleanup_expired_sessions_success(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test successful cleanup of expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create mix of sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -30,9 +34,9 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# 2. Inactive, expired, old (SHOULD be deleted)
|
||||
@@ -43,9 +47,9 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=10),
|
||||
created_at=datetime.now(UTC) - timedelta(days=40),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# 3. Inactive, expired, recent (should NOT be deleted - within keep_days)
|
||||
@@ -56,17 +60,23 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.3",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=5),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
created_at=datetime.now(UTC) - timedelta(days=5),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
session.add_all([active_session, old_expired_session, recent_expired_session])
|
||||
session.add_all(
|
||||
[active_session, old_expired_session, recent_expired_session]
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Mock SessionLocal to return our test session
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
# Should only delete old_expired_session
|
||||
@@ -85,7 +95,7 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_no_sessions_to_delete(self, async_test_db, async_test_user):
|
||||
"""Test cleanup when no sessions meet deletion criteria."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active = UserSession(
|
||||
@@ -95,15 +105,19 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
created_at=datetime.now(UTC),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(active)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 0
|
||||
@@ -111,10 +125,14 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_empty_database(self, async_test_db):
|
||||
"""Test cleanup with no sessions in database."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 0
|
||||
@@ -122,7 +140,7 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_with_keep_days_0(self, async_test_db, async_test_user):
|
||||
"""Test cleanup with keep_days=0 deletes all inactive expired sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
today_expired = UserSession(
|
||||
@@ -132,15 +150,19 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
created_at=datetime.now(UTC) - timedelta(hours=2),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(today_expired)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=0)
|
||||
|
||||
assert deleted_count == 1
|
||||
@@ -148,7 +170,7 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_bulk_delete_efficiency(self, async_test_db, async_test_user):
|
||||
"""Test that cleanup uses bulk DELETE for many sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create 50 expired sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -161,16 +183,20 @@ class TestCleanupExpiredSessions:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=10),
|
||||
created_at=datetime.now(UTC) - timedelta(days=40),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
sessions_to_add.append(expired)
|
||||
session.add_all(sessions_to_add)
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
assert deleted_count == 50
|
||||
@@ -178,14 +204,20 @@ class TestCleanupExpiredSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_database_error_returns_zero(self, async_test_db):
|
||||
"""Test cleanup returns 0 on database errors (doesn't crash)."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Mock session_crud.cleanup_expired to raise error
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch('app.services.session_cleanup.session_crud.cleanup_expired') as mock_cleanup:
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
with patch(
|
||||
"app.services.session_cleanup.session_crud.cleanup_expired"
|
||||
) as mock_cleanup:
|
||||
mock_cleanup.side_effect = Exception("Database connection lost")
|
||||
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
# Should not crash, should return 0
|
||||
deleted_count = await cleanup_expired_sessions(keep_days=30)
|
||||
|
||||
@@ -198,7 +230,7 @@ class TestGetSessionStatistics:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_with_sessions(self, async_test_db, async_test_user):
|
||||
"""Test getting session statistics with various session types."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
# 2 active, not expired
|
||||
@@ -210,9 +242,9 @@ class TestGetSessionStatistics:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) + timedelta(days=7),
|
||||
created_at=datetime.now(UTC),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(active)
|
||||
|
||||
@@ -225,9 +257,9 @@ class TestGetSessionStatistics:
|
||||
ip_address="192.168.1.2",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=2),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
created_at=datetime.now(UTC) - timedelta(days=2),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(inactive)
|
||||
|
||||
@@ -239,16 +271,20 @@ class TestGetSessionStatistics:
|
||||
ip_address="192.168.1.3",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=True,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(hours=1),
|
||||
created_at=datetime.now(UTC) - timedelta(days=1),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(expired_active)
|
||||
|
||||
await session.commit()
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats["total"] == 6
|
||||
@@ -259,10 +295,14 @@ class TestGetSessionStatistics:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_empty_database(self, async_test_db):
|
||||
"""Test getting statistics with no sessions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats["total"] == 0
|
||||
@@ -271,9 +311,11 @@ class TestGetSessionStatistics:
|
||||
assert stats["expired"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_database_error_returns_empty_dict(self, async_test_db):
|
||||
async def test_get_statistics_database_error_returns_empty_dict(
|
||||
self, async_test_db
|
||||
):
|
||||
"""Test statistics returns empty dict on database errors."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, _AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a mock that raises on execute
|
||||
mock_session = AsyncMock()
|
||||
@@ -283,8 +325,12 @@ class TestGetSessionStatistics:
|
||||
async def mock_session_local():
|
||||
yield mock_session
|
||||
|
||||
with patch('app.services.session_cleanup.SessionLocal', return_value=mock_session_local()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
return_value=mock_session_local(),
|
||||
):
|
||||
from app.services.session_cleanup import get_session_statistics
|
||||
|
||||
stats = await get_session_statistics()
|
||||
|
||||
assert stats == {}
|
||||
@@ -294,9 +340,11 @@ class TestConcurrentCleanup:
|
||||
"""Tests for concurrent cleanup scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_cleanup_no_duplicate_deletes(self, async_test_db, async_test_user):
|
||||
async def test_concurrent_cleanup_no_duplicate_deletes(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test concurrent cleanups don't cause race conditions."""
|
||||
test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create 10 expired sessions
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
@@ -308,20 +356,24 @@ class TestConcurrentCleanup:
|
||||
ip_address="192.168.1.1",
|
||||
user_agent="Mozilla/5.0",
|
||||
is_active=False,
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=10),
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=40),
|
||||
last_used_at=datetime.now(timezone.utc)
|
||||
expires_at=datetime.now(UTC) - timedelta(days=10),
|
||||
created_at=datetime.now(UTC) - timedelta(days=40),
|
||||
last_used_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(expired)
|
||||
await session.commit()
|
||||
|
||||
# Run two cleanups concurrently
|
||||
# Use side_effect to return fresh session instances for each call
|
||||
with patch('app.services.session_cleanup.SessionLocal', side_effect=lambda: AsyncTestingSessionLocal()):
|
||||
with patch(
|
||||
"app.services.session_cleanup.SessionLocal",
|
||||
side_effect=lambda: AsyncTestingSessionLocal(),
|
||||
):
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
results = await asyncio.gather(
|
||||
cleanup_expired_sessions(keep_days=30),
|
||||
cleanup_expired_sessions(keep_days=30)
|
||||
cleanup_expired_sessions(keep_days=30),
|
||||
)
|
||||
|
||||
# Both should report deleting sessions (may overlap due to transaction timing)
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
"""
|
||||
Tests for database initialization script.
|
||||
"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.init_db import init_db
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.config import settings
|
||||
from app.init_db import init_db
|
||||
|
||||
|
||||
class TestInitDb:
|
||||
@@ -16,69 +17,86 @@ class TestInitDb:
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_creates_superuser_when_not_exists(self, async_test_db):
|
||||
"""Test that init_db creates a superuser when one doesn't exist."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock the SessionLocal to use our test database
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
with patch("app.init_db.SessionLocal", SessionLocal):
|
||||
# Mock settings to provide test credentials
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test_admin@example.com'):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestAdmin123!'):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_EMAIL", "test_admin@example.com"
|
||||
):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_PASSWORD", "TestAdmin123!"
|
||||
):
|
||||
# Run init_db
|
||||
user = await init_db()
|
||||
|
||||
# Verify superuser was created
|
||||
assert user is not None
|
||||
assert user.email == 'test_admin@example.com'
|
||||
assert user.email == "test_admin@example.com"
|
||||
assert user.is_superuser is True
|
||||
assert user.first_name == 'Admin'
|
||||
assert user.last_name == 'User'
|
||||
assert user.first_name == "Admin"
|
||||
assert user.last_name == "User"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_returns_existing_superuser(self, async_test_db, async_test_user):
|
||||
async def test_init_db_returns_existing_superuser(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
"""Test that init_db returns existing superuser instead of creating duplicate."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock the SessionLocal to use our test database
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
with patch("app.init_db.SessionLocal", SessionLocal):
|
||||
# Mock settings to match async_test_user's email
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'testuser@example.com'):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_EMAIL", "testuser@example.com"
|
||||
):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
|
||||
):
|
||||
# Run init_db
|
||||
user = await init_db()
|
||||
|
||||
# Verify it returns the existing user
|
||||
assert user is not None
|
||||
assert user.id == async_test_user.id
|
||||
assert user.email == 'testuser@example.com'
|
||||
assert user.email == "testuser@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_uses_default_credentials(self, async_test_db):
|
||||
"""Test that init_db uses default credentials when env vars not set."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock the SessionLocal to use our test database
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
with patch("app.init_db.SessionLocal", SessionLocal):
|
||||
# Mock settings to have None values (not configured)
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', None):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', None):
|
||||
with patch.object(settings, "FIRST_SUPERUSER_EMAIL", None):
|
||||
with patch.object(settings, "FIRST_SUPERUSER_PASSWORD", None):
|
||||
# Run init_db
|
||||
user = await init_db()
|
||||
|
||||
# Verify superuser was created with defaults
|
||||
assert user is not None
|
||||
assert user.email == 'admin@example.com'
|
||||
assert user.email == "admin@example.com"
|
||||
assert user.is_superuser is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_handles_database_errors(self, async_test_db):
|
||||
"""Test that init_db handles database errors gracefully."""
|
||||
test_engine, SessionLocal = async_test_db
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Mock user_crud.get_by_email to raise an exception
|
||||
with patch('app.init_db.user_crud.get_by_email', side_effect=Exception("Database error")):
|
||||
with patch('app.init_db.SessionLocal', SessionLocal):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_EMAIL', 'test@example.com'):
|
||||
with patch.object(settings, 'FIRST_SUPERUSER_PASSWORD', 'TestPassword123!'):
|
||||
with patch(
|
||||
"app.init_db.user_crud.get_by_email",
|
||||
side_effect=Exception("Database error"),
|
||||
):
|
||||
with patch("app.init_db.SessionLocal", SessionLocal):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_EMAIL", "test@example.com"
|
||||
):
|
||||
with patch.object(
|
||||
settings, "FIRST_SUPERUSER_PASSWORD", "TestPassword123!"
|
||||
):
|
||||
# Run init_db and expect it to raise
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
await init_db()
|
||||
|
||||
@@ -2,18 +2,18 @@
|
||||
"""
|
||||
Comprehensive tests for device utility functions.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.utils.device import (
|
||||
extract_device_info,
|
||||
parse_device_name,
|
||||
extract_browser,
|
||||
extract_device_info,
|
||||
get_client_ip,
|
||||
get_device_type,
|
||||
is_mobile_device,
|
||||
get_device_type
|
||||
parse_device_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -138,7 +138,9 @@ class TestExtractBrowser:
|
||||
|
||||
def test_extract_browser_edge_legacy(self):
|
||||
"""Test extracting legacy Edge browser."""
|
||||
ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
|
||||
ua = (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Edge/18.19582"
|
||||
)
|
||||
result = extract_browser(ua)
|
||||
assert result == "Edge"
|
||||
|
||||
@@ -249,7 +251,7 @@ class TestGetClientIp:
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {
|
||||
"x-forwarded-for": "192.168.1.100",
|
||||
"x-real-ip": "192.168.1.200"
|
||||
"x-real-ip": "192.168.1.200",
|
||||
}
|
||||
request.client = Mock()
|
||||
request.client.host = "192.168.1.50"
|
||||
@@ -385,7 +387,7 @@ class TestExtractDeviceInfo:
|
||||
request.headers = {
|
||||
"user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 15_0 like Mac OS X)",
|
||||
"x-device-id": "device-123-456",
|
||||
"x-forwarded-for": "192.168.1.100"
|
||||
"x-forwarded-for": "192.168.1.100",
|
||||
}
|
||||
request.client = None
|
||||
|
||||
|
||||
@@ -2,19 +2,21 @@
|
||||
"""
|
||||
Tests for security utility functions.
|
||||
"""
|
||||
import time
|
||||
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.utils.security import (
|
||||
create_upload_token,
|
||||
verify_upload_token,
|
||||
create_password_reset_token,
|
||||
verify_password_reset_token,
|
||||
create_email_verification_token,
|
||||
verify_email_verification_token
|
||||
create_password_reset_token,
|
||||
create_upload_token,
|
||||
verify_email_verification_token,
|
||||
verify_password_reset_token,
|
||||
verify_upload_token,
|
||||
)
|
||||
|
||||
|
||||
@@ -31,7 +33,7 @@ class TestCreateUploadToken:
|
||||
|
||||
# Token should be base64 encoded
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
assert "payload" in token_data
|
||||
assert "signature" in token_data
|
||||
@@ -46,7 +48,7 @@ class TestCreateUploadToken:
|
||||
token = create_upload_token(file_path, content_type)
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
@@ -62,7 +64,7 @@ class TestCreateUploadToken:
|
||||
after = int(time.time())
|
||||
|
||||
# Decode token
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
@@ -74,11 +76,13 @@ class TestCreateUploadToken:
|
||||
"""Test token creation with custom expiration time."""
|
||||
custom_exp = 600 # 10 minutes
|
||||
before = int(time.time())
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=custom_exp)
|
||||
token = create_upload_token(
|
||||
"/uploads/test.jpg", "image/jpeg", expires_in=custom_exp
|
||||
)
|
||||
after = int(time.time())
|
||||
|
||||
# Decode token
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
payload = token_data["payload"]
|
||||
|
||||
@@ -92,11 +96,11 @@ class TestCreateUploadToken:
|
||||
token2 = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode both tokens
|
||||
decoded1 = base64.urlsafe_b64decode(token1.encode('utf-8'))
|
||||
decoded1 = base64.urlsafe_b64decode(token1.encode("utf-8"))
|
||||
token_data1 = json.loads(decoded1)
|
||||
nonce1 = token_data1["payload"]["nonce"]
|
||||
|
||||
decoded2 = base64.urlsafe_b64decode(token2.encode('utf-8'))
|
||||
decoded2 = base64.urlsafe_b64decode(token2.encode("utf-8"))
|
||||
token_data2 = json.loads(decoded2)
|
||||
nonce2 = token_data2["payload"]["nonce"]
|
||||
|
||||
@@ -133,7 +137,7 @@ class TestVerifyUploadToken:
|
||||
current_time = 1000000
|
||||
mock_time.time = MagicMock(return_value=current_time)
|
||||
|
||||
with patch('app.utils.security.time', mock_time):
|
||||
with patch("app.utils.security.time", mock_time):
|
||||
# Create token that "expires" at current_time + 1
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg", expires_in=1)
|
||||
|
||||
@@ -149,13 +153,15 @@ class TestVerifyUploadToken:
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode, modify, and re-encode
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
token_data["signature"] = "invalid_signature"
|
||||
|
||||
# Re-encode the tampered token
|
||||
tampered_json = json.dumps(token_data)
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
payload = verify_upload_token(tampered_token)
|
||||
assert payload is None
|
||||
@@ -165,13 +171,15 @@ class TestVerifyUploadToken:
|
||||
token = create_upload_token("/uploads/test.jpg", "image/jpeg")
|
||||
|
||||
# Decode, modify payload, and re-encode
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8'))
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8"))
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["path"] = "/uploads/hacked.exe"
|
||||
|
||||
# Re-encode the tampered token (signature won't match)
|
||||
tampered_json = json.dumps(token_data)
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode('utf-8')).decode('utf-8')
|
||||
tampered_token = base64.urlsafe_b64encode(tampered_json.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
payload = verify_upload_token(tampered_token)
|
||||
assert payload is None
|
||||
@@ -194,7 +202,9 @@ class TestVerifyUploadToken:
|
||||
"""Test that tokens with invalid JSON are rejected."""
|
||||
# Create a base64 string that decodes to invalid JSON
|
||||
invalid_json = "not valid json"
|
||||
invalid_token = base64.urlsafe_b64encode(invalid_json.encode('utf-8')).decode('utf-8')
|
||||
invalid_token = base64.urlsafe_b64encode(invalid_json.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
payload = verify_upload_token(invalid_token)
|
||||
assert payload is None
|
||||
@@ -207,11 +217,13 @@ class TestVerifyUploadToken:
|
||||
"path": "/uploads/test.jpg"
|
||||
# Missing content_type, exp, nonce
|
||||
},
|
||||
"signature": "some_signature"
|
||||
"signature": "some_signature",
|
||||
}
|
||||
|
||||
incomplete_json = json.dumps(incomplete_data)
|
||||
incomplete_token = base64.urlsafe_b64encode(incomplete_json.encode('utf-8')).decode('utf-8')
|
||||
incomplete_token = base64.urlsafe_b64encode(
|
||||
incomplete_json.encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
payload = verify_upload_token(incomplete_token)
|
||||
assert payload is None
|
||||
@@ -266,7 +278,7 @@ class TestPasswordResetTokens:
|
||||
email = "user@example.com"
|
||||
|
||||
# Create token that expires in 1 second
|
||||
with patch('app.utils.security.time') as mock_time:
|
||||
with patch("app.utils.security.time") as mock_time:
|
||||
mock_time.time = MagicMock(return_value=1000000)
|
||||
token = create_password_reset_token(email, expires_in=1)
|
||||
|
||||
@@ -287,12 +299,14 @@ class TestPasswordResetTokens:
|
||||
token = create_password_reset_token(email)
|
||||
|
||||
# Decode and tamper
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["email"] = "hacker@example.com"
|
||||
|
||||
# Re-encode
|
||||
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
|
||||
tampered = base64.urlsafe_b64encode(
|
||||
json.dumps(token_data).encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
verified_email = verify_password_reset_token(tampered)
|
||||
assert verified_email is None
|
||||
@@ -312,14 +326,14 @@ class TestPasswordResetTokens:
|
||||
email = "user@example.com"
|
||||
custom_exp = 7200 # 2 hours
|
||||
|
||||
with patch('app.utils.security.time') as mock_time:
|
||||
with patch("app.utils.security.time") as mock_time:
|
||||
current_time = 1000000
|
||||
mock_time.time = MagicMock(return_value=current_time)
|
||||
|
||||
token = create_password_reset_token(email, expires_in=custom_exp)
|
||||
|
||||
# Decode to check expiration
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
|
||||
assert token_data["payload"]["exp"] == current_time + custom_exp
|
||||
@@ -350,7 +364,7 @@ class TestEmailVerificationTokens:
|
||||
"""Test that expired verification tokens are rejected."""
|
||||
email = "user@example.com"
|
||||
|
||||
with patch('app.utils.security.time') as mock_time:
|
||||
with patch("app.utils.security.time") as mock_time:
|
||||
mock_time.time = MagicMock(return_value=1000000)
|
||||
token = create_email_verification_token(email, expires_in=1)
|
||||
|
||||
@@ -371,12 +385,14 @@ class TestEmailVerificationTokens:
|
||||
token = create_email_verification_token(email)
|
||||
|
||||
# Decode and tamper
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
token_data["payload"]["email"] = "hacker@example.com"
|
||||
|
||||
# Re-encode
|
||||
tampered = base64.urlsafe_b64encode(json.dumps(token_data).encode('utf-8')).decode('utf-8')
|
||||
tampered = base64.urlsafe_b64encode(
|
||||
json.dumps(token_data).encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
verified_email = verify_email_verification_token(tampered)
|
||||
assert verified_email is None
|
||||
@@ -395,14 +411,14 @@ class TestEmailVerificationTokens:
|
||||
"""Test email verification token with default 24-hour expiration."""
|
||||
email = "user@example.com"
|
||||
|
||||
with patch('app.utils.security.time') as mock_time:
|
||||
with patch("app.utils.security.time") as mock_time:
|
||||
current_time = 1000000
|
||||
mock_time.time = MagicMock(return_value=current_time)
|
||||
|
||||
token = create_email_verification_token(email)
|
||||
|
||||
# Decode to check expiration (should be 86400 seconds = 24 hours)
|
||||
decoded = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
decoded = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(decoded)
|
||||
|
||||
assert token_data["payload"]["exp"] == current_time + 86400
|
||||
|
||||
Reference in New Issue
Block a user