Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff

- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest).
- Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting.
- Updated `requirements.txt` to include Ruff and remove replaced tools.
- Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
2025-11-10 11:55:15 +01:00
parent a5c671c133
commit c589b565f0
86 changed files with 4572 additions and 3956 deletions

View File

@@ -2,12 +2,11 @@ import sys
from logging.config import fileConfig
from pathlib import Path
from sqlalchemy import engine_from_config, pool, text, create_engine
from alembic import context
from sqlalchemy import create_engine, engine_from_config, pool, text
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError
from alembic import context
# Get the path to the app directory (parent of 'alembic')
app_dir = Path(__file__).resolve().parent.parent
# Add the app directory to Python path
@@ -66,7 +65,9 @@ def ensure_database_exists(db_url: str) -> None:
admin_url = url.set(database="postgres")
# CREATE DATABASE cannot run inside a transaction
admin_engine = create_engine(str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool)
admin_engine = create_engine(
str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool
)
try:
with admin_engine.connect() as conn:
exists = conn.execute(
@@ -122,9 +123,7 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
@@ -133,4 +132,4 @@ def run_migrations_online() -> None:
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
run_migrations_online()

View File

@@ -5,17 +5,17 @@ Revises: fbf6318a8a36
Create Date: 2025-11-01 04:15:25.367010
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '1174fffbe3e4'
down_revision: Union[str, None] = 'fbf6318a8a36'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "1174fffbe3e4"
down_revision: str | None = "fbf6318a8a36"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
@@ -24,46 +24,46 @@ def upgrade() -> None:
# Index for session cleanup queries
# Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff
op.create_index(
'ix_user_sessions_cleanup',
'user_sessions',
['is_active', 'expires_at', 'created_at'],
"ix_user_sessions_cleanup",
"user_sessions",
["is_active", "expires_at", "created_at"],
unique=False,
postgresql_where=sa.text('is_active = false')
postgresql_where=sa.text("is_active = false"),
)
# Index for user search queries (basic trigram support without pg_trgm extension)
# Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%'
# Note: For better performance, consider enabling pg_trgm extension
op.create_index(
'ix_users_email_lower',
'users',
[sa.text('LOWER(email)')],
"ix_users_email_lower",
"users",
[sa.text("LOWER(email)")],
unique=False,
postgresql_where=sa.text('deleted_at IS NULL')
postgresql_where=sa.text("deleted_at IS NULL"),
)
op.create_index(
'ix_users_first_name_lower',
'users',
[sa.text('LOWER(first_name)')],
"ix_users_first_name_lower",
"users",
[sa.text("LOWER(first_name)")],
unique=False,
postgresql_where=sa.text('deleted_at IS NULL')
postgresql_where=sa.text("deleted_at IS NULL"),
)
op.create_index(
'ix_users_last_name_lower',
'users',
[sa.text('LOWER(last_name)')],
"ix_users_last_name_lower",
"users",
[sa.text("LOWER(last_name)")],
unique=False,
postgresql_where=sa.text('deleted_at IS NULL')
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Index for organization search
op.create_index(
'ix_organizations_name_lower',
'organizations',
[sa.text('LOWER(name)')],
unique=False
"ix_organizations_name_lower",
"organizations",
[sa.text("LOWER(name)")],
unique=False,
)
@@ -71,8 +71,8 @@ def downgrade() -> None:
"""Remove performance indexes."""
# Drop indexes in reverse order
op.drop_index('ix_organizations_name_lower', table_name='organizations')
op.drop_index('ix_users_last_name_lower', table_name='users')
op.drop_index('ix_users_first_name_lower', table_name='users')
op.drop_index('ix_users_email_lower', table_name='users')
op.drop_index('ix_user_sessions_cleanup', table_name='user_sessions')
op.drop_index("ix_organizations_name_lower", table_name="organizations")
op.drop_index("ix_users_last_name_lower", table_name="users")
op.drop_index("ix_users_first_name_lower", table_name="users")
op.drop_index("ix_users_email_lower", table_name="users")
op.drop_index("ix_user_sessions_cleanup", table_name="user_sessions")

View File

@@ -5,30 +5,32 @@ Revises: 9e4f2a1b8c7d
Create Date: 2025-10-30 16:40:21.000021
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '2d0fcec3b06d'
down_revision: Union[str, None] = '9e4f2a1b8c7d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "2d0fcec3b06d"
down_revision: str | None = "9e4f2a1b8c7d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Add deleted_at column for soft deletes
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
op.add_column(
"users", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
)
# Add index on deleted_at for efficient queries
op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])
op.create_index("ix_users_deleted_at", "users", ["deleted_at"])
def downgrade() -> None:
# Remove index
op.drop_index('ix_users_deleted_at', table_name='users')
op.drop_index("ix_users_deleted_at", table_name="users")
# Remove column
op.drop_column('users', 'deleted_at')
op.drop_column("users", "deleted_at")

View File

@@ -5,42 +5,42 @@ Revises: 7396957cbe80
Create Date: 2025-02-28 09:19:33.212278
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '38bf9e7e74b3'
down_revision: Union[str, None] = '7396957cbe80'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "38bf9e7e74b3"
down_revision: str | None = "7396957cbe80"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.create_table('users',
sa.Column('email', sa.String(), nullable=False),
sa.Column('password_hash', sa.String(), nullable=False),
sa.Column('first_name', sa.String(), nullable=False),
sa.Column('last_name', sa.String(), nullable=True),
sa.Column('phone_number', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('is_superuser', sa.Boolean(), nullable=False),
sa.Column('preferences', sa.JSON(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
op.create_table(
"users",
sa.Column("email", sa.String(), nullable=False),
sa.Column("password_hash", sa.String(), nullable=False),
sa.Column("first_name", sa.String(), nullable=False),
sa.Column("last_name", sa.String(), nullable=True),
sa.Column("phone_number", sa.String(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column("preferences", sa.JSON(), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')
op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_table("users")
# ### end Alembic commands ###

View File

@@ -5,98 +5,85 @@ Revises: b76c725fc3cf
Create Date: 2025-10-31 07:41:18.729544
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '549b50ea888d'
down_revision: Union[str, None] = 'b76c725fc3cf'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "549b50ea888d"
down_revision: str | None = "b76c725fc3cf"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Create user_sessions table for per-device session management
op.create_table(
'user_sessions',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
sa.Column('device_name', sa.String(length=255), nullable=True),
sa.Column('device_id', sa.String(length=255), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.String(length=500), nullable=True),
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('location_city', sa.String(length=100), nullable=True),
sa.Column('location_country', sa.String(length=100), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
"user_sessions",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
sa.Column("device_name", sa.String(length=255), nullable=True),
sa.Column("device_id", sa.String(length=255), nullable=True),
sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column("user_agent", sa.String(length=500), nullable=True),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("location_city", sa.String(length=100), nullable=True),
sa.Column("location_country", sa.String(length=100), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create foreign key to users table
op.create_foreign_key(
'fk_user_sessions_user_id',
'user_sessions',
'users',
['user_id'],
['id'],
ondelete='CASCADE'
"fk_user_sessions_user_id",
"user_sessions",
"users",
["user_id"],
["id"],
ondelete="CASCADE",
)
# Create indexes for performance
# 1. Lookup session by refresh token JTI (most common query)
op.create_index(
'ix_user_sessions_jti',
'user_sessions',
['refresh_token_jti'],
unique=True
"ix_user_sessions_jti", "user_sessions", ["refresh_token_jti"], unique=True
)
# 2. Lookup sessions by user ID
op.create_index(
'ix_user_sessions_user_id',
'user_sessions',
['user_id']
)
op.create_index("ix_user_sessions_user_id", "user_sessions", ["user_id"])
# 3. Composite index for active sessions by user
op.create_index(
'ix_user_sessions_user_active',
'user_sessions',
['user_id', 'is_active']
"ix_user_sessions_user_active", "user_sessions", ["user_id", "is_active"]
)
# 4. Index on expires_at for cleanup job
op.create_index(
'ix_user_sessions_expires_at',
'user_sessions',
['expires_at']
)
op.create_index("ix_user_sessions_expires_at", "user_sessions", ["expires_at"])
# 5. Composite index for active session lookup by JTI
op.create_index(
'ix_user_sessions_jti_active',
'user_sessions',
['refresh_token_jti', 'is_active']
"ix_user_sessions_jti_active",
"user_sessions",
["refresh_token_jti", "is_active"],
)
def downgrade() -> None:
# Drop indexes first
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
op.drop_index('ix_user_sessions_expires_at', table_name='user_sessions')
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
op.drop_index('ix_user_sessions_user_id', table_name='user_sessions')
op.drop_index('ix_user_sessions_jti', table_name='user_sessions')
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
op.drop_index("ix_user_sessions_expires_at", table_name="user_sessions")
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
op.drop_index("ix_user_sessions_user_id", table_name="user_sessions")
op.drop_index("ix_user_sessions_jti", table_name="user_sessions")
# Drop foreign key
op.drop_constraint('fk_user_sessions_user_id', 'user_sessions', type_='foreignkey')
op.drop_constraint("fk_user_sessions_user_id", "user_sessions", type_="foreignkey")
# Drop table
op.drop_table('user_sessions')
op.drop_table("user_sessions")

View File

@@ -1,19 +1,18 @@
"""Initial empty migration
Revision ID: 7396957cbe80
Revises:
Revises:
Create Date: 2025-02-27 12:47:46.445313
"""
from typing import Sequence, Union
from alembic import op
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = '7396957cbe80'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "7396957cbe80"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:

View File

@@ -5,80 +5,112 @@ Revises: 38bf9e7e74b3
Create Date: 2025-10-30 10:00:00.000000
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '9e4f2a1b8c7d'
down_revision: Union[str, None] = '38bf9e7e74b3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "9e4f2a1b8c7d"
down_revision: str | None = "38bf9e7e74b3"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Add missing indexes for is_active and is_superuser
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
op.create_index(
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
)
# Fix column types to match model definitions with explicit lengths
op.alter_column('users', 'email',
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False)
op.alter_column(
"users",
"email",
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False,
)
op.alter_column('users', 'password_hash',
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False)
op.alter_column(
"users",
"password_hash",
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False,
)
op.alter_column('users', 'first_name',
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=False,
server_default='user') # Add server default
op.alter_column(
"users",
"first_name",
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=False,
server_default="user",
) # Add server default
op.alter_column('users', 'last_name',
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=True)
op.alter_column(
"users",
"last_name",
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=True,
)
op.alter_column('users', 'phone_number',
existing_type=sa.String(),
type_=sa.String(length=20),
nullable=True)
op.alter_column(
"users",
"phone_number",
existing_type=sa.String(),
type_=sa.String(length=20),
nullable=True,
)
def downgrade() -> None:
# Revert column types
op.alter_column('users', 'phone_number',
existing_type=sa.String(length=20),
type_=sa.String(),
nullable=True)
op.alter_column(
"users",
"phone_number",
existing_type=sa.String(length=20),
type_=sa.String(),
nullable=True,
)
op.alter_column('users', 'last_name',
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=True)
op.alter_column(
"users",
"last_name",
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=True,
)
op.alter_column('users', 'first_name',
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=False,
server_default=None) # Remove server default
op.alter_column(
"users",
"first_name",
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=False,
server_default=None,
) # Remove server default
op.alter_column('users', 'password_hash',
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False)
op.alter_column(
"users",
"password_hash",
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False,
)
op.alter_column('users', 'email',
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False)
op.alter_column(
"users",
"email",
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False,
)
# Drop indexes
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
op.drop_index(op.f('ix_users_is_active'), table_name='users')
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
op.drop_index(op.f("ix_users_is_active"), table_name="users")

View File

@@ -5,17 +5,17 @@ Revises: 2d0fcec3b06d
Create Date: 2025-10-30 16:41:33.273135
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'b76c725fc3cf'
down_revision: Union[str, None] = '2d0fcec3b06d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "b76c725fc3cf"
down_revision: str | None = "2d0fcec3b06d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
@@ -23,30 +23,26 @@ def upgrade() -> None:
# Composite index for filtering active users by role
op.create_index(
'ix_users_active_superuser',
'users',
['is_active', 'is_superuser'],
postgresql_where=sa.text('deleted_at IS NULL')
"ix_users_active_superuser",
"users",
["is_active", "is_superuser"],
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Composite index for sorting active users by creation date
op.create_index(
'ix_users_active_created',
'users',
['is_active', 'created_at'],
postgresql_where=sa.text('deleted_at IS NULL')
"ix_users_active_created",
"users",
["is_active", "created_at"],
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Composite index for email lookup of non-deleted users
op.create_index(
'ix_users_email_not_deleted',
'users',
['email', 'deleted_at']
)
op.create_index("ix_users_email_not_deleted", "users", ["email", "deleted_at"])
def downgrade() -> None:
# Remove composite indexes
op.drop_index('ix_users_email_not_deleted', table_name='users')
op.drop_index('ix_users_active_created', table_name='users')
op.drop_index('ix_users_active_superuser', table_name='users')
op.drop_index("ix_users_email_not_deleted", table_name="users")
op.drop_index("ix_users_active_created", table_name="users")
op.drop_index("ix_users_active_superuser", table_name="users")

View File

@@ -5,102 +5,123 @@ Revises: 549b50ea888d
Create Date: 2025-10-31 12:08:05.141353
"""
from typing import Sequence, Union
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'fbf6318a8a36'
down_revision: Union[str, None] = '549b50ea888d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "fbf6318a8a36"
down_revision: str | None = "549b50ea888d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Create organizations table
op.create_table(
'organizations',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('slug', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('settings', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
"organizations",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("slug", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("settings", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create indexes for organizations
op.create_index('ix_organizations_name', 'organizations', ['name'])
op.create_index('ix_organizations_slug', 'organizations', ['slug'], unique=True)
op.create_index('ix_organizations_is_active', 'organizations', ['is_active'])
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'])
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'])
op.create_index("ix_organizations_name", "organizations", ["name"])
op.create_index("ix_organizations_slug", "organizations", ["slug"], unique=True)
op.create_index("ix_organizations_is_active", "organizations", ["is_active"])
op.create_index(
"ix_organizations_name_active", "organizations", ["name", "is_active"]
)
op.create_index(
"ix_organizations_slug_active", "organizations", ["slug", "is_active"]
)
# Create user_organizations junction table
op.create_table(
'user_organizations',
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('organization_id', sa.UUID(), nullable=False),
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False, server_default='MEMBER'),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('user_id', 'organization_id')
"user_organizations",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("organization_id", sa.UUID(), nullable=False),
sa.Column(
"role",
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
nullable=False,
server_default="MEMBER",
),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("user_id", "organization_id"),
)
# Create foreign keys
op.create_foreign_key(
'fk_user_organizations_user_id',
'user_organizations',
'users',
['user_id'],
['id'],
ondelete='CASCADE'
"fk_user_organizations_user_id",
"user_organizations",
"users",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
'fk_user_organizations_organization_id',
'user_organizations',
'organizations',
['organization_id'],
['id'],
ondelete='CASCADE'
"fk_user_organizations_organization_id",
"user_organizations",
"organizations",
["organization_id"],
["id"],
ondelete="CASCADE",
)
# Create indexes for user_organizations
op.create_index('ix_user_organizations_role', 'user_organizations', ['role'])
op.create_index('ix_user_organizations_is_active', 'user_organizations', ['is_active'])
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'])
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'])
op.create_index("ix_user_organizations_role", "user_organizations", ["role"])
op.create_index(
"ix_user_organizations_is_active", "user_organizations", ["is_active"]
)
op.create_index(
"ix_user_org_user_active", "user_organizations", ["user_id", "is_active"]
)
op.create_index(
"ix_user_org_org_active", "user_organizations", ["organization_id", "is_active"]
)
def downgrade() -> None:
# Drop indexes for user_organizations
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
op.drop_index('ix_user_organizations_is_active', table_name='user_organizations')
op.drop_index('ix_user_organizations_role', table_name='user_organizations')
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
op.drop_index("ix_user_organizations_is_active", table_name="user_organizations")
op.drop_index("ix_user_organizations_role", table_name="user_organizations")
# Drop foreign keys
op.drop_constraint('fk_user_organizations_organization_id', 'user_organizations', type_='foreignkey')
op.drop_constraint('fk_user_organizations_user_id', 'user_organizations', type_='foreignkey')
op.drop_constraint(
"fk_user_organizations_organization_id",
"user_organizations",
type_="foreignkey",
)
op.drop_constraint(
"fk_user_organizations_user_id", "user_organizations", type_="foreignkey"
)
# Drop user_organizations table
op.drop_table('user_organizations')
op.drop_table("user_organizations")
# Drop indexes for organizations
op.drop_index('ix_organizations_slug_active', table_name='organizations')
op.drop_index('ix_organizations_name_active', table_name='organizations')
op.drop_index('ix_organizations_is_active', table_name='organizations')
op.drop_index('ix_organizations_slug', table_name='organizations')
op.drop_index('ix_organizations_name', table_name='organizations')
op.drop_index("ix_organizations_slug_active", table_name="organizations")
op.drop_index("ix_organizations_name_active", table_name="organizations")
op.drop_index("ix_organizations_is_active", table_name="organizations")
op.drop_index("ix_organizations_slug", table_name="organizations")
op.drop_index("ix_organizations_name", table_name="organizations")
# Drop organizations table
op.drop_table('organizations')
op.drop_table("organizations")
# Drop enum type
op.execute('DROP TYPE IF EXISTS organizationrole')
op.execute("DROP TYPE IF EXISTS organizationrole")

View File

@@ -1,12 +1,10 @@
from typing import Optional
from fastapi import Depends, HTTPException, status, Header
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
from app.core.database import get_db
from app.models.user import User
@@ -15,8 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
db: AsyncSession = Depends(get_db),
token: str = Depends(oauth2_scheme)
db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
) -> User:
"""
Get the current authenticated user.
@@ -36,21 +33,17 @@ async def get_current_user(
token_data = get_token_data(token)
# Get user from database
result = await db.execute(
select(User).where(User.id == token_data.user_id)
)
result = await db.execute(select(User).where(User.id == token_data.user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
@@ -59,19 +52,17 @@ async def get_current_user(
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token expired",
headers={"WWW-Authenticate": "Bearer"}
headers={"WWW-Authenticate": "Bearer"},
)
except TokenInvalidError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}
headers={"WWW-Authenticate": "Bearer"},
)
def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
"""
Check if the current user is active.
@@ -86,15 +77,12 @@ def get_current_active_user(
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return current_user
def get_current_superuser(
current_user: User = Depends(get_current_user)
) -> User:
def get_current_superuser(current_user: User = Depends(get_current_user)) -> User:
"""
Check if the current user is a superuser.
@@ -109,13 +97,12 @@ def get_current_superuser(
"""
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions"
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
)
return current_user
async def get_optional_token(authorization: str = Header(None)) -> Optional[str]:
async def get_optional_token(authorization: str = Header(None)) -> str | None:
"""
Get the token from the Authorization header without requiring it.
@@ -139,9 +126,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
async def get_optional_current_user(
db: AsyncSession = Depends(get_db),
token: Optional[str] = Depends(get_optional_token)
) -> Optional[User]:
db: AsyncSession = Depends(get_db), token: str | None = Depends(get_optional_token)
) -> User | None:
"""
Get the current user if authenticated, otherwise return None.
Useful for endpoints that work with both authenticated and unauthenticated users.
@@ -158,12 +144,10 @@ async def get_optional_current_user(
try:
token_data = get_token_data(token)
result = await db.execute(
select(User).where(User.id == token_data.user_id)
)
result = await db.execute(select(User).where(User.id == token_data.user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active:
return None
return user
except (TokenExpiredError, TokenInvalidError):
return None
return None

View File

@@ -7,7 +7,7 @@ These dependencies are optional and flexible:
- Use require_org_role for organization-specific access control
- Projects can choose to use these or implement their own permission system
"""
from typing import Optional
from uuid import UUID
from fastapi import Depends, HTTPException, status
@@ -20,9 +20,7 @@ from app.models.user import User
from app.models.user_organization import OrganizationRole
def require_superuser(
current_user: User = Depends(get_current_user)
) -> User:
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
"""
Dependency to ensure the current user is a superuser.
@@ -36,7 +34,7 @@ def require_superuser(
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Superuser privileges required"
detail="Superuser privileges required",
)
return current_user
@@ -62,7 +60,7 @@ class OrganizationPermission:
self,
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> User:
"""
Check if user has required role in the organization.
@@ -84,21 +82,19 @@ class OrganizationPermission:
# Get user's role in organization
user_role = await organization_crud.get_user_role_in_org(
db,
user_id=current_user.id,
organization_id=organization_id
db, user_id=current_user.id, organization_id=organization_id
)
if not user_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not a member of this organization"
detail="Not a member of this organization",
)
if user_role not in self.allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}"
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}",
)
return current_user
@@ -106,18 +102,18 @@ class OrganizationPermission:
# Common permission presets for convenience
require_org_owner = OrganizationPermission([OrganizationRole.OWNER])
require_org_admin = OrganizationPermission([OrganizationRole.OWNER, OrganizationRole.ADMIN])
require_org_member = OrganizationPermission([
OrganizationRole.OWNER,
OrganizationRole.ADMIN,
OrganizationRole.MEMBER
])
require_org_admin = OrganizationPermission(
[OrganizationRole.OWNER, OrganizationRole.ADMIN]
)
require_org_member = OrganizationPermission(
[OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MEMBER]
)
async def require_org_membership(
organization_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> User:
"""
Ensure user is a member of the organization (any role).
@@ -128,15 +124,13 @@ async def require_org_membership(
return current_user
user_role = await organization_crud.get_user_role_in_org(
db,
user_id=current_user.id,
organization_id=organization_id
db, user_id=current_user.id, organization_id=organization_id
)
if not user_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not a member of this organization"
detail="Not a member of this organization",
)
return current_user

View File

@@ -1,10 +1,12 @@
from fastapi import APIRouter
from app.api.routes import auth, users, sessions, admin, organizations
from app.api.routes import admin, auth, organizations, sessions, users
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
api_router.include_router(users.router, prefix="/users", tags=["Users"])
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
api_router.include_router(organizations.router, prefix="/organizations", tags=["Organizations"])
api_router.include_router(
organizations.router, prefix="/organizations", tags=["Organizations"]
)

View File

@@ -5,9 +5,10 @@ Admin-specific endpoints for managing users and organizations.
These endpoints require superuser privileges and provide CMS-like functionality
for managing the application.
"""
import logging
from enum import Enum
from typing import Any, List, Optional
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query, status
@@ -16,27 +17,32 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.permissions import require_superuser
from app.core.database import get_db
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode
from app.core.exceptions import (
AuthorizationError,
DuplicateError,
ErrorCode,
NotFoundError,
)
from app.crud.organization import organization as organization_crud
from app.crud.user import user as user_crud
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.user import User
from app.models.user_organization import OrganizationRole
from app.schemas.common import (
PaginationParams,
PaginatedResponse,
MessageResponse,
PaginatedResponse,
PaginationParams,
SortParams,
create_pagination_meta
create_pagination_meta,
)
from app.schemas.organizations import (
OrganizationResponse,
OrganizationCreate,
OrganizationMemberResponse,
OrganizationResponse,
OrganizationUpdate,
OrganizationMemberResponse
)
from app.schemas.users import UserResponse, UserCreate, UserUpdate
from app.schemas.sessions import AdminSessionResponse
from app.schemas.users import UserCreate, UserResponse, UserUpdate
logger = logging.getLogger(__name__)
@@ -46,6 +52,7 @@ router = APIRouter()
# Schemas for bulk operations
class BulkAction(str, Enum):
"""Supported bulk actions."""
ACTIVATE = "activate"
DEACTIVATE = "deactivate"
DELETE = "delete"
@@ -53,36 +60,41 @@ class BulkAction(str, Enum):
class BulkUserAction(BaseModel):
"""Schema for bulk user actions."""
action: BulkAction = Field(..., description="Action to perform on selected users")
user_ids: List[UUID] = Field(..., min_items=1, max_items=100, description="List of user IDs (max 100)")
user_ids: list[UUID] = Field(
..., min_items=1, max_items=100, description="List of user IDs (max 100)"
)
class BulkActionResult(BaseModel):
"""Result of a bulk action."""
success: bool
affected_count: int
failed_count: int
message: str
failed_ids: Optional[List[UUID]] = []
failed_ids: list[UUID] | None = []
# ===== User Management Endpoints =====
@router.get(
"/users",
response_model=PaginatedResponse[UserResponse],
summary="Admin: List All Users",
description="Get paginated list of all users with filtering and search (admin only)",
operation_id="admin_list_users"
operation_id="admin_list_users",
)
async def admin_list_users(
pagination: PaginationParams = Depends(),
sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
search: Optional[str] = Query(None, description="Search by email, name"),
is_active: bool | None = Query(None, description="Filter by active status"),
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
search: str | None = Query(None, description="Search by email, name"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List all users with comprehensive filtering and search.
@@ -105,20 +117,20 @@ async def admin_list_users(
sort_by=sort.sort_by or "created_at",
sort_order=sort.sort_order.value if sort.sort_order else "desc",
filters=filters if filters else None,
search=search
search=search,
)
pagination_meta = create_pagination_meta(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(users)
items_count=len(users),
)
return PaginatedResponse(data=users, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing users (admin): {str(e)}", exc_info=True)
logger.error(f"Error listing users (admin): {e!s}", exc_info=True)
raise
@@ -128,12 +140,12 @@ async def admin_list_users(
status_code=status.HTTP_201_CREATED,
summary="Admin: Create User",
description="Create a new user (admin only)",
operation_id="admin_create_user"
operation_id="admin_create_user",
)
async def admin_create_user(
user_in: UserCreate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Create a new user with admin privileges.
@@ -145,13 +157,10 @@ async def admin_create_user(
logger.info(f"Admin {admin.email} created user {user.email}")
return user
except ValueError as e:
logger.warning(f"Failed to create user: {str(e)}")
raise NotFoundError(
message=str(e),
error_code=ErrorCode.USER_ALREADY_EXISTS
)
logger.warning(f"Failed to create user: {e!s}")
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
except Exception as e:
logger.error(f"Error creating user (admin): {str(e)}", exc_info=True)
logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
raise
@@ -160,19 +169,18 @@ async def admin_create_user(
response_model=UserResponse,
summary="Admin: Get User Details",
description="Get detailed user information (admin only)",
operation_id="admin_get_user"
operation_id="admin_get_user",
)
async def admin_get_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Get detailed information about a specific user."""
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
return user
@@ -182,21 +190,20 @@ async def admin_get_user(
response_model=UserResponse,
summary="Admin: Update User",
description="Update user information (admin only)",
operation_id="admin_update_user"
operation_id="admin_update_user",
)
async def admin_update_user(
user_id: UUID,
user_in: UserUpdate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Update user information with admin privileges."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
@@ -206,7 +213,7 @@ async def admin_update_user(
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error updating user (admin): {str(e)}", exc_info=True)
logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
raise
@@ -215,20 +222,19 @@ async def admin_update_user(
response_model=MessageResponse,
summary="Admin: Delete User",
description="Soft delete a user (admin only)",
operation_id="admin_delete_user"
operation_id="admin_delete_user",
)
async def admin_delete_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Soft delete a user (sets deleted_at timestamp)."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
# Prevent deleting yourself
@@ -236,21 +242,20 @@ async def admin_delete_user(
# Use AuthorizationError for permission/operation restrictions
raise AuthorizationError(
message="Cannot delete your own account",
error_code=ErrorCode.OPERATION_FORBIDDEN
error_code=ErrorCode.OPERATION_FORBIDDEN,
)
await user_crud.soft_delete(db, id=user_id)
logger.info(f"Admin {admin.email} deleted user {user.email}")
return MessageResponse(
success=True,
message=f"User {user.email} has been deleted"
success=True, message=f"User {user.email} has been deleted"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error deleting user (admin): {str(e)}", exc_info=True)
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
raise
@@ -259,34 +264,32 @@ async def admin_delete_user(
response_model=MessageResponse,
summary="Admin: Activate User",
description="Activate a user account (admin only)",
operation_id="admin_activate_user"
operation_id="admin_activate_user",
)
async def admin_activate_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Activate a user account."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
logger.info(f"Admin {admin.email} activated user {user.email}")
return MessageResponse(
success=True,
message=f"User {user.email} has been activated"
success=True, message=f"User {user.email} has been activated"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error activating user (admin): {str(e)}", exc_info=True)
logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
raise
@@ -295,20 +298,19 @@ async def admin_activate_user(
response_model=MessageResponse,
summary="Admin: Deactivate User",
description="Deactivate a user account (admin only)",
operation_id="admin_deactivate_user"
operation_id="admin_deactivate_user",
)
async def admin_deactivate_user(
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Deactivate a user account."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
# Prevent deactivating yourself
@@ -316,21 +318,20 @@ async def admin_deactivate_user(
# Use AuthorizationError for permission/operation restrictions
raise AuthorizationError(
message="Cannot deactivate your own account",
error_code=ErrorCode.OPERATION_FORBIDDEN
error_code=ErrorCode.OPERATION_FORBIDDEN,
)
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
logger.info(f"Admin {admin.email} deactivated user {user.email}")
return MessageResponse(
success=True,
message=f"User {user.email} has been deactivated"
success=True, message=f"User {user.email} has been deactivated"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error deactivating user (admin): {str(e)}", exc_info=True)
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
raise
@@ -339,12 +340,12 @@ async def admin_deactivate_user(
response_model=BulkActionResult,
summary="Admin: Bulk User Action",
description="Perform bulk actions on multiple users (admin only)",
operation_id="admin_bulk_user_action"
operation_id="admin_bulk_user_action",
)
async def admin_bulk_user_action(
bulk_action: BulkUserAction,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Perform bulk actions on multiple users using optimized bulk operations.
@@ -356,22 +357,16 @@ async def admin_bulk_user_action(
# Use efficient bulk operations instead of loop
if bulk_action.action == BulkAction.ACTIVATE:
affected_count = await user_crud.bulk_update_status(
db,
user_ids=bulk_action.user_ids,
is_active=True
db, user_ids=bulk_action.user_ids, is_active=True
)
elif bulk_action.action == BulkAction.DEACTIVATE:
affected_count = await user_crud.bulk_update_status(
db,
user_ids=bulk_action.user_ids,
is_active=False
db, user_ids=bulk_action.user_ids, is_active=False
)
elif bulk_action.action == BulkAction.DELETE:
# bulk_soft_delete automatically excludes the admin user
affected_count = await user_crud.bulk_soft_delete(
db,
user_ids=bulk_action.user_ids,
exclude_user_id=admin.id
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
)
else:
raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
@@ -390,29 +385,30 @@ async def admin_bulk_user_action(
affected_count=affected_count,
failed_count=failed_count,
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped",
failed_ids=None # Bulk operations don't track individual failures
failed_ids=None, # Bulk operations don't track individual failures
)
except Exception as e:
logger.error(f"Error in bulk user action: {str(e)}", exc_info=True)
logger.error(f"Error in bulk user action: {e!s}", exc_info=True)
raise
# ===== Organization Management Endpoints =====
@router.get(
"/organizations",
response_model=PaginatedResponse[OrganizationResponse],
summary="Admin: List Organizations",
description="Get paginated list of all organizations (admin only)",
operation_id="admin_list_organizations"
operation_id="admin_list_organizations",
)
async def admin_list_organizations(
pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
search: Optional[str] = Query(None, description="Search by name, slug, description"),
is_active: bool | None = Query(None, description="Filter by active status"),
search: str | None = Query(None, description="Search by name, slug, description"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""List all organizations with filtering and search."""
try:
@@ -422,14 +418,14 @@ async def admin_list_organizations(
skip=pagination.offset,
limit=pagination.limit,
is_active=is_active,
search=search
search=search,
)
# Build response objects from optimized query results
orgs_with_count = []
for item in orgs_with_data:
org = item['organization']
member_count = item['member_count']
org = item["organization"]
member_count = item["member_count"]
org_dict = {
"id": org.id,
@@ -440,7 +436,7 @@ async def admin_list_organizations(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": member_count
"member_count": member_count,
}
orgs_with_count.append(OrganizationResponse(**org_dict))
@@ -448,13 +444,13 @@ async def admin_list_organizations(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(orgs_with_count)
items_count=len(orgs_with_count),
)
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing organizations (admin): {str(e)}", exc_info=True)
logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True)
raise
@@ -464,12 +460,12 @@ async def admin_list_organizations(
status_code=status.HTTP_201_CREATED,
summary="Admin: Create Organization",
description="Create a new organization (admin only)",
operation_id="admin_create_organization"
operation_id="admin_create_organization",
)
async def admin_create_organization(
org_in: OrganizationCreate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Create a new organization."""
try:
@@ -486,18 +482,15 @@ async def admin_create_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": 0
"member_count": 0,
}
return OrganizationResponse(**org_dict)
except ValueError as e:
logger.warning(f"Failed to create organization: {str(e)}")
raise NotFoundError(
message=str(e),
error_code=ErrorCode.ALREADY_EXISTS
)
logger.warning(f"Failed to create organization: {e!s}")
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
except Exception as e:
logger.error(f"Error creating organization (admin): {str(e)}", exc_info=True)
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
raise
@@ -506,19 +499,18 @@ async def admin_create_organization(
response_model=OrganizationResponse,
summary="Admin: Get Organization Details",
description="Get detailed organization information (admin only)",
operation_id="admin_get_organization"
operation_id="admin_get_organization",
)
async def admin_get_organization(
org_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Get detailed information about a specific organization."""
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
)
org_dict = {
@@ -530,7 +522,9 @@ async def admin_get_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
"member_count": await organization_crud.get_member_count(
db, organization_id=org.id
),
}
return OrganizationResponse(**org_dict)
@@ -540,13 +534,13 @@ async def admin_get_organization(
response_model=OrganizationResponse,
summary="Admin: Update Organization",
description="Update organization information (admin only)",
operation_id="admin_update_organization"
operation_id="admin_update_organization",
)
async def admin_update_organization(
org_id: UUID,
org_in: OrganizationUpdate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Update organization information."""
try:
@@ -554,7 +548,7 @@ async def admin_update_organization(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
@@ -569,14 +563,16 @@ async def admin_update_organization(
"settings": updated_org.settings,
"created_at": updated_org.created_at,
"updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
"member_count": await organization_crud.get_member_count(
db, organization_id=updated_org.id
),
}
return OrganizationResponse(**org_dict)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error updating organization (admin): {str(e)}", exc_info=True)
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
raise
@@ -585,12 +581,12 @@ async def admin_update_organization(
response_model=MessageResponse,
summary="Admin: Delete Organization",
description="Delete an organization (admin only)",
operation_id="admin_delete_organization"
operation_id="admin_delete_organization",
)
async def admin_delete_organization(
org_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Delete an organization and all its relationships."""
try:
@@ -598,21 +594,20 @@ async def admin_delete_organization(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
await organization_crud.remove(db, id=org_id)
logger.info(f"Admin {admin.email} deleted organization {org.name}")
return MessageResponse(
success=True,
message=f"Organization {org.name} has been deleted"
success=True, message=f"Organization {org.name} has been deleted"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error deleting organization (admin): {str(e)}", exc_info=True)
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
raise
@@ -621,14 +616,14 @@ async def admin_delete_organization(
response_model=PaginatedResponse[OrganizationMemberResponse],
summary="Admin: List Organization Members",
description="Get all members of an organization (admin only)",
operation_id="admin_list_organization_members"
operation_id="admin_list_organization_members",
)
async def admin_list_organization_members(
org_id: UUID,
pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(True, description="Filter by active status"),
is_active: bool | None = Query(True, description="Filter by active status"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""List all members of an organization."""
try:
@@ -636,7 +631,7 @@ async def admin_list_organization_members(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
members, total = await organization_crud.get_organization_members(
@@ -644,7 +639,7 @@ async def admin_list_organization_members(
organization_id=org_id,
skip=pagination.offset,
limit=pagination.limit,
is_active=is_active
is_active=is_active,
)
# Convert to response models
@@ -654,7 +649,7 @@ async def admin_list_organization_members(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(member_responses)
items_count=len(member_responses),
)
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
@@ -662,14 +657,19 @@ async def admin_list_organization_members(
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error listing organization members (admin): {str(e)}", exc_info=True)
logger.error(
f"Error listing organization members (admin): {e!s}", exc_info=True
)
raise
class AddMemberRequest(BaseModel):
"""Request to add a member to an organization."""
user_id: UUID = Field(..., description="User ID to add")
role: OrganizationRole = Field(OrganizationRole.MEMBER, description="Role in organization")
role: OrganizationRole = Field(
OrganizationRole.MEMBER, description="Role in organization"
)
@router.post(
@@ -677,13 +677,13 @@ class AddMemberRequest(BaseModel):
response_model=MessageResponse,
summary="Admin: Add Member to Organization",
description="Add a user to an organization (admin only)",
operation_id="admin_add_organization_member"
operation_id="admin_add_organization_member",
)
async def admin_add_organization_member(
org_id: UUID,
request: AddMemberRequest,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Add a user to an organization."""
try:
@@ -691,21 +691,18 @@ async def admin_add_organization_member(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
user = await user_crud.get(db, id=request.user_id)
if not user:
raise NotFoundError(
message=f"User {request.user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND,
)
await organization_crud.add_user(
db,
organization_id=org_id,
user_id=request.user_id,
role=request.role
db, organization_id=org_id, user_id=request.user_id, role=request.role
)
logger.info(
@@ -714,22 +711,21 @@ async def admin_add_organization_member(
)
return MessageResponse(
success=True,
message=f"User {user.email} added to organization {org.name}"
success=True, message=f"User {user.email} added to organization {org.name}"
)
except ValueError as e:
logger.warning(f"Failed to add user to organization: {str(e)}")
logger.warning(f"Failed to add user to organization: {e!s}")
# Use DuplicateError for "already exists" scenarios
raise DuplicateError(
message=str(e),
error_code=ErrorCode.USER_ALREADY_EXISTS,
field="user_id"
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error adding member to organization (admin): {str(e)}", exc_info=True)
logger.error(
f"Error adding member to organization (admin): {e!s}", exc_info=True
)
raise
@@ -738,13 +734,13 @@ async def admin_add_organization_member(
response_model=MessageResponse,
summary="Admin: Remove Member from Organization",
description="Remove a user from an organization (admin only)",
operation_id="admin_remove_organization_member"
operation_id="admin_remove_organization_member",
)
async def admin_remove_organization_member(
org_id: UUID,
user_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""Remove a user from an organization."""
try:
@@ -752,39 +748,40 @@ async def admin_remove_organization_member(
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
success = await organization_crud.remove_user(
db,
organization_id=org_id,
user_id=user_id
db, organization_id=org_id, user_id=user_id
)
if not success:
raise NotFoundError(
message="User is not a member of this organization",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
logger.info(f"Admin {admin.email} removed user {user.email} from organization {org.name}")
logger.info(
f"Admin {admin.email} removed user {user.email} from organization {org.name}"
)
return MessageResponse(
success=True,
message=f"User {user.email} removed from organization {org.name}"
message=f"User {user.email} removed from organization {org.name}",
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error removing member from organization (admin): {str(e)}", exc_info=True)
logger.error(
f"Error removing member from organization (admin): {e!s}", exc_info=True
)
raise
@@ -792,6 +789,7 @@ async def admin_remove_organization_member(
# Session Management Endpoints
# ============================================================================
@router.get(
"/sessions",
response_model=PaginatedResponse[AdminSessionResponse],
@@ -802,13 +800,13 @@ async def admin_remove_organization_member(
Returns paginated list of sessions with user information.
Useful for admin dashboard statistics and session monitoring.
""",
operation_id="admin_list_sessions"
operation_id="admin_list_sessions",
)
async def admin_list_sessions(
pagination: PaginationParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_active: bool | None = Query(None, description="Filter by active status"),
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""List all sessions across all users with filtering and pagination."""
try:
@@ -818,7 +816,7 @@ async def admin_list_sessions(
skip=pagination.offset,
limit=pagination.limit,
active_only=is_active if is_active is not None else True,
with_user=True
with_user=True,
)
# Build response objects with user information
@@ -847,21 +845,23 @@ async def admin_list_sessions(
last_used_at=session.last_used_at,
created_at=session.created_at,
expires_at=session.expires_at,
is_active=session.is_active
is_active=session.is_active,
)
session_responses.append(session_response)
logger.info(f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})")
logger.info(
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})"
)
pagination_meta = create_pagination_meta(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(session_responses)
items_count=len(session_responses),
)
return PaginatedResponse(data=session_responses, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing sessions (admin): {str(e)}", exc_info=True)
logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True)
raise

View File

@@ -1,39 +1,43 @@
# app/api/routes/auth.py
import logging
import os
from datetime import datetime, timezone
from datetime import UTC, datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
from app.core.auth import get_password_hash
from app.core.auth import (
TokenExpiredError,
TokenInvalidError,
decode_token,
get_password_hash,
)
from app.core.database import get_db
from app.core.exceptions import (
AuthenticationError as AuthError,
DatabaseError,
ErrorCode
ErrorCode,
)
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionCreate, LogoutRequest
from app.schemas.sessions import LogoutRequest, SessionCreate
from app.schemas.users import (
LoginRequest,
PasswordResetConfirm,
PasswordResetRequest,
RefreshTokenRequest,
Token,
UserCreate,
UserResponse,
Token,
LoginRequest,
RefreshTokenRequest,
PasswordResetRequest,
PasswordResetConfirm
)
from app.services.auth_service import AuthService, AuthenticationError
from app.services.auth_service import AuthenticationError, AuthService
from app.services.email_service import email_service
from app.utils.device import extract_device_info
from app.utils.security import create_password_reset_token, verify_password_reset_token
@@ -54,7 +58,7 @@ async def _create_login_session(
request: Request,
user: User,
tokens: Token,
login_type: str = "login"
login_type: str = "login",
) -> None:
"""
Create a session record for successful login.
@@ -81,8 +85,8 @@ async def _create_login_session(
device_id=device_info.device_id,
ip_address=device_info.ip_address,
user_agent=device_info.user_agent,
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
last_used_at=datetime.now(UTC),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
location_city=device_info.location_city,
location_country=device_info.location_country,
)
@@ -95,15 +99,20 @@ async def _create_login_session(
)
except Exception as session_err:
# Log but don't fail login if session creation fails
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
logger.error(
f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
@router.post(
"/register",
response_model=UserResponse,
status_code=status.HTTP_201_CREATED,
operation_id="register",
)
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
async def register_user(
request: Request,
user_data: UserCreate,
db: AsyncSession = Depends(get_db)
request: Request, user_data: UserCreate, db: AsyncSession = Depends(get_db)
) -> Any:
"""
Register a new user.
@@ -116,25 +125,23 @@ async def register_user(
return user
except AuthenticationError as e:
# SECURITY: Don't reveal if email exists - generic error message
logger.warning(f"Registration failed: {str(e)}")
logger.warning(f"Registration failed: {e!s}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Registration failed. Please check your information and try again."
detail="Registration failed. Please check your information and try again.",
)
except Exception as e:
logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True)
logger.error(f"Unexpected error during registration: {e!s}", exc_info=True)
raise DatabaseError(
message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR
error_code=ErrorCode.INTERNAL_ERROR,
)
@router.post("/login", response_model=Token, operation_id="login")
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def login(
request: Request,
login_data: LoginRequest,
db: AsyncSession = Depends(get_db)
request: Request, login_data: LoginRequest, db: AsyncSession = Depends(get_db)
) -> Any:
"""
Login with username and password.
@@ -146,14 +153,16 @@ async def login(
"""
try:
# Attempt to authenticate the user
user = await AuthService.authenticate_user(db, login_data.email, login_data.password)
user = await AuthService.authenticate_user(
db, login_data.email, login_data.password
)
# Explicitly check for None result and raise correct exception
if user is None:
logger.warning(f"Invalid login attempt for: {login_data.email}")
raise AuthError(
message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS
error_code=ErrorCode.INVALID_CREDENTIALS,
)
# User is authenticated, generate tokens
@@ -166,29 +175,26 @@ async def login(
except AuthenticationError as e:
# Handle specific authentication errors like inactive accounts
logger.warning(f"Authentication failed: {str(e)}")
raise AuthError(
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
)
logger.warning(f"Authentication failed: {e!s}")
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error during login: {str(e)}", exc_info=True)
logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
raise DatabaseError(
message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR
error_code=ErrorCode.INTERNAL_ERROR,
)
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
@router.post("/login/oauth", response_model=Token, operation_id="login_oauth")
@limiter.limit("10/minute")
async def login_oauth(
request: Request,
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db)
request: Request,
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
OAuth2-compatible login endpoint, used by the OpenAPI UI.
@@ -199,12 +205,14 @@ async def login_oauth(
Access and refresh tokens.
"""
try:
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
user = await AuthService.authenticate_user(
db, form_data.username, form_data.password
)
if user is None:
raise AuthError(
message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS
error_code=ErrorCode.INVALID_CREDENTIALS,
)
# Generate tokens
@@ -216,28 +224,25 @@ async def login_oauth(
# Return full token response with user data
return tokens
except AuthenticationError as e:
logger.warning(f"OAuth authentication failed: {str(e)}")
raise AuthError(
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
)
logger.warning(f"OAuth authentication failed: {e!s}")
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e:
logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True)
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
raise DatabaseError(
message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR
error_code=ErrorCode.INTERNAL_ERROR,
)
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
@limiter.limit("30/minute")
async def refresh_token(
request: Request,
refresh_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_db)
request: Request,
refresh_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Refresh access token using a refresh token.
@@ -249,13 +254,17 @@ async def refresh_token(
"""
try:
# Decode the refresh token to get the JTI
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
refresh_payload = decode_token(
refresh_data.refresh_token, verify_type="refresh"
)
# Check if session exists and is active
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
if not session:
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
logger.warning(
f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Session has been revoked. Please log in again.",
@@ -274,10 +283,12 @@ async def refresh_token(
db,
session=session,
new_jti=new_refresh_payload.jti,
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc)
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
)
except Exception as session_err:
logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True)
logger.error(
f"Failed to update session {session.id}: {session_err!s}", exc_info=True
)
# Continue anyway - tokens are already issued
return tokens
@@ -300,10 +311,10 @@ async def refresh_token(
# Re-raise HTTP exceptions (like session revoked)
raise
except Exception as e:
logger.error(f"Unexpected error during token refresh: {str(e)}")
logger.error(f"Unexpected error during token refresh: {e!s}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
detail="An unexpected error occurred. Please try again later.",
)
@@ -320,13 +331,13 @@ async def refresh_token(
**Rate Limit**: 3 requests/minute
""",
operation_id="request_password_reset"
operation_id="request_password_reset",
)
@limiter.limit("3/minute")
async def request_password_reset(
request: Request,
reset_request: PasswordResetRequest,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Request a password reset.
@@ -345,26 +356,26 @@ async def request_password_reset(
# Send password reset email
await email_service.send_password_reset_email(
to_email=user.email,
reset_token=reset_token,
user_name=user.first_name
to_email=user.email, reset_token=reset_token, user_name=user.first_name
)
logger.info(f"Password reset requested for {user.email}")
else:
# Log attempt but don't reveal if email exists
logger.warning(f"Password reset requested for non-existent or inactive email: {reset_request.email}")
logger.warning(
f"Password reset requested for non-existent or inactive email: {reset_request.email}"
)
# Always return success to prevent email enumeration
return MessageResponse(
success=True,
message="If your email is registered, you will receive a password reset link shortly"
message="If your email is registered, you will receive a password reset link shortly",
)
except Exception as e:
logger.error(f"Error processing password reset request: {str(e)}", exc_info=True)
logger.error(f"Error processing password reset request: {e!s}", exc_info=True)
# Still return success to prevent information leakage
return MessageResponse(
success=True,
message="If your email is registered, you will receive a password reset link shortly"
message="If your email is registered, you will receive a password reset link shortly",
)
@@ -378,13 +389,13 @@ async def request_password_reset(
**Rate Limit**: 5 requests/minute
""",
operation_id="confirm_password_reset"
operation_id="confirm_password_reset",
)
@limiter.limit("5/minute")
async def confirm_password_reset(
request: Request,
reset_confirm: PasswordResetConfirm,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Confirm password reset with token.
@@ -398,7 +409,7 @@ async def confirm_password_reset(
if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired password reset token"
detail="Invalid or expired password reset token",
)
# Look up user
@@ -406,14 +417,13 @@ async def confirm_password_reset(
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User account is inactive"
detail="User account is inactive",
)
# Update password
@@ -424,29 +434,33 @@ async def confirm_password_reset(
# SECURITY: Invalidate all existing sessions after password reset
# This prevents stolen sessions from being used after password change
from app.crud.session import session as session_crud
try:
deactivated_count = await session_crud.deactivate_all_user_sessions(
db,
user_id=str(user.id)
db, user_id=str(user.id)
)
logger.info(
f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions"
)
logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions")
except Exception as session_error:
# Log but don't fail password reset if session invalidation fails
logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}")
logger.error(
f"Failed to invalidate sessions after password reset: {session_error!s}"
)
return MessageResponse(
success=True,
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password."
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password.",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
logger.error(f"Error confirming password reset: {e!s}", exc_info=True)
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while resetting your password"
detail="An error occurred while resetting your password",
)
@@ -464,14 +478,14 @@ async def confirm_password_reset(
**Rate Limit**: 10 requests/minute
""",
operation_id="logout"
operation_id="logout",
)
@limiter.limit("10/minute")
async def logout(
request: Request,
logout_request: LogoutRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Logout from current device by deactivating the session.
@@ -487,15 +501,14 @@ async def logout(
try:
# Decode refresh token to get JTI
try:
refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh")
refresh_payload = decode_token(
logout_request.refresh_token, verify_type="refresh"
)
except (TokenExpiredError, TokenInvalidError) as e:
# Even if token is expired/invalid, try to deactivate session
logger.warning(f"Logout with invalid/expired token: {str(e)}")
logger.warning(f"Logout with invalid/expired token: {e!s}")
# Don't fail - return success anyway
return MessageResponse(
success=True,
message="Logged out successfully"
)
return MessageResponse(success=True, message="Logged out successfully")
# Find the session by JTI
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
@@ -509,7 +522,7 @@ async def logout(
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only logout your own sessions"
detail="You can only logout your own sessions",
)
# Deactivate the session
@@ -522,22 +535,20 @@ async def logout(
else:
# Session not found - maybe already deleted or never existed
# Return success anyway (idempotent)
logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})")
logger.info(
f"Logout requested for non-existent session (JTI: {refresh_payload.jti})"
)
return MessageResponse(
success=True,
message="Logged out successfully"
)
return MessageResponse(success=True, message="Logged out successfully")
except HTTPException:
raise
except Exception as e:
logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True)
# Don't expose error details
return MessageResponse(
success=True,
message="Logged out successfully"
logger.error(
f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
)
# Don't expose error details
return MessageResponse(success=True, message="Logged out successfully")
@router.post(
@@ -553,13 +564,13 @@ async def logout(
**Rate Limit**: 5 requests/minute
""",
operation_id="logout_all"
operation_id="logout_all",
)
@limiter.limit("5/minute")
async def logout_all(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Logout from all devices by deactivating all user sessions.
@@ -573,19 +584,25 @@ async def logout_all(
"""
try:
# Deactivate all sessions for this user
count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
count = await session_crud.deactivate_all_user_sessions(
db, user_id=str(current_user.id)
)
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
logger.info(
f"User {current_user.id} logged out from all devices ({count} sessions)"
)
return MessageResponse(
success=True,
message=f"Successfully logged out from all devices ({count} sessions terminated)"
message=f"Successfully logged out from all devices ({count} sessions terminated)",
)
except Exception as e:
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
logger.error(
f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
)
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred while logging out"
detail="An error occurred while logging out",
)

View File

@@ -4,8 +4,9 @@ Organization endpoints for regular users.
These endpoints allow users to view and manage organizations they belong to.
"""
import logging
from typing import Any, List
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query
@@ -14,18 +15,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership
from app.core.database import get_db
from app.core.exceptions import NotFoundError, ErrorCode
from app.core.exceptions import ErrorCode, NotFoundError
from app.crud.organization import organization as organization_crud
from app.models.user import User
from app.schemas.common import (
PaginationParams,
PaginatedResponse,
create_pagination_meta
PaginationParams,
create_pagination_meta,
)
from app.schemas.organizations import (
OrganizationResponse,
OrganizationMemberResponse,
OrganizationUpdate
OrganizationResponse,
OrganizationUpdate,
)
logger = logging.getLogger(__name__)
@@ -35,15 +36,15 @@ router = APIRouter()
@router.get(
"/me",
response_model=List[OrganizationResponse],
response_model=list[OrganizationResponse],
summary="Get My Organizations",
description="Get all organizations the current user belongs to",
operation_id="get_my_organizations"
operation_id="get_my_organizations",
)
async def get_my_organizations(
is_active: bool = Query(True, description="Filter by active membership"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get all organizations the current user belongs to.
@@ -54,15 +55,13 @@ async def get_my_organizations(
try:
# Get all org data in single query with JOIN and subquery
orgs_data = await organization_crud.get_user_organizations_with_details(
db,
user_id=current_user.id,
is_active=is_active
db, user_id=current_user.id, is_active=is_active
)
# Transform to response objects
orgs_with_data = []
for item in orgs_data:
org = item['organization']
org = item["organization"]
org_dict = {
"id": org.id,
"name": org.name,
@@ -72,14 +71,14 @@ async def get_my_organizations(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": item['member_count']
"member_count": item["member_count"],
}
orgs_with_data.append(OrganizationResponse(**org_dict))
return orgs_with_data
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}", exc_info=True)
logger.error(f"Error getting user organizations: {e!s}", exc_info=True)
raise
@@ -88,12 +87,12 @@ async def get_my_organizations(
response_model=OrganizationResponse,
summary="Get Organization Details",
description="Get details of an organization the user belongs to",
operation_id="get_organization"
operation_id="get_organization",
)
async def get_organization(
organization_id: UUID,
current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get details of a specific organization.
@@ -105,7 +104,7 @@ async def get_organization(
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError(
detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
org_dict = {
@@ -117,14 +116,16 @@ async def get_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
"member_count": await organization_crud.get_member_count(
db, organization_id=org.id
),
}
return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e:
logger.error(f"Error getting organization: {str(e)}", exc_info=True)
logger.error(f"Error getting organization: {e!s}", exc_info=True)
raise
@@ -133,14 +134,14 @@ async def get_organization(
response_model=PaginatedResponse[OrganizationMemberResponse],
summary="Get Organization Members",
description="Get all members of an organization (members can view)",
operation_id="get_organization_members"
operation_id="get_organization_members",
)
async def get_organization_members(
organization_id: UUID,
pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"),
current_user: User = Depends(require_org_membership),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get all members of an organization.
@@ -153,7 +154,7 @@ async def get_organization_members(
organization_id=organization_id,
skip=pagination.offset,
limit=pagination.limit,
is_active=is_active
is_active=is_active,
)
member_responses = [OrganizationMemberResponse(**member) for member in members]
@@ -162,13 +163,13 @@ async def get_organization_members(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(member_responses)
items_count=len(member_responses),
)
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}", exc_info=True)
logger.error(f"Error getting organization members: {e!s}", exc_info=True)
raise
@@ -177,13 +178,13 @@ async def get_organization_members(
response_model=OrganizationResponse,
summary="Update Organization",
description="Update organization details (admin/owner only)",
operation_id="update_organization"
operation_id="update_organization",
)
async def update_organization(
organization_id: UUID,
org_in: OrganizationUpdate,
current_user: User = Depends(require_org_admin),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update organization details.
@@ -195,11 +196,13 @@ async def update_organization(
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError(
detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info(f"User {current_user.email} updated organization {updated_org.name}")
logger.info(
f"User {current_user.email} updated organization {updated_org.name}"
)
org_dict = {
"id": updated_org.id,
@@ -210,12 +213,14 @@ async def update_organization(
"settings": updated_org.settings,
"created_at": updated_org.created_at,
"updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
"member_count": await organization_crud.get_member_count(
db, organization_id=updated_org.id
),
}
return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e:
logger.error(f"Error updating organization: {str(e)}", exc_info=True)
logger.error(f"Error updating organization: {e!s}", exc_info=True)
raise

View File

@@ -3,11 +3,12 @@ Session management endpoints.
Allows users to view and manage their active sessions across devices.
"""
import logging
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, Depends, HTTPException, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
@@ -15,11 +16,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.auth import decode_token
from app.core.database import get_db
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
from app.crud.session import session as session_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionResponse, SessionListResponse
from app.schemas.sessions import SessionListResponse, SessionResponse
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -39,13 +40,13 @@ limiter = Limiter(key_func=get_remote_address)
**Rate Limit**: 30 requests/minute
""",
operation_id="list_my_sessions"
operation_id="list_my_sessions",
)
@limiter.limit("30/minute")
async def list_my_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List all active sessions for the current user.
@@ -60,18 +61,15 @@ async def list_my_sessions(
try:
# Get all active sessions for user
sessions = await session_crud.get_user_sessions(
db,
user_id=str(current_user.id),
active_only=True
db, user_id=str(current_user.id), active_only=True
)
# Try to identify current session from Authorization header
current_session_jti = None
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
try:
access_token = auth_header.split(" ")[1]
token_payload = decode_token(access_token)
decode_token(access_token)
# Note: Access tokens don't have JTI by default, but we can try
# For now, we'll mark current based on most recent activity
except Exception:
@@ -90,22 +88,27 @@ async def list_my_sessions(
last_used_at=s.last_used_at,
created_at=s.created_at,
expires_at=s.expires_at,
is_current=(s == sessions[0] if sessions else False) # Most recent = current
is_current=(
s == sessions[0] if sessions else False
), # Most recent = current
)
session_responses.append(session_response)
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
logger.info(
f"User {current_user.id} listed {len(session_responses)} active sessions"
)
return SessionListResponse(
sessions=session_responses,
total=len(session_responses)
sessions=session_responses, total=len(session_responses)
)
except Exception as e:
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
logger.error(
f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve sessions"
detail="Failed to retrieve sessions",
)
@@ -122,14 +125,14 @@ async def list_my_sessions(
**Rate Limit**: 10 requests/minute
""",
operation_id="revoke_session"
operation_id="revoke_session",
)
@limiter.limit("10/minute")
async def revoke_session(
request: Request,
session_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Revoke a specific session by ID.
@@ -149,7 +152,7 @@ async def revoke_session(
if not session:
raise NotFoundError(
message=f"Session {session_id} not found",
error_code=ErrorCode.NOT_FOUND
error_code=ErrorCode.NOT_FOUND,
)
# Verify session belongs to current user
@@ -160,7 +163,7 @@ async def revoke_session(
)
raise AuthorizationError(
message="You can only revoke your own sessions",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Deactivate the session
@@ -173,16 +176,16 @@ async def revoke_session(
return MessageResponse(
success=True,
message=f"Session revoked: {session.device_name or 'Unknown device'}"
message=f"Session revoked: {session.device_name or 'Unknown device'}",
)
except (NotFoundError, AuthorizationError):
raise
except Exception as e:
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke session"
detail="Failed to revoke session",
)
@@ -198,13 +201,13 @@ async def revoke_session(
**Rate Limit**: 5 requests/minute
""",
operation_id="cleanup_expired_sessions"
operation_id="cleanup_expired_sessions",
)
@limiter.limit("5/minute")
async def cleanup_expired_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Cleanup expired sessions for the current user.
@@ -219,21 +222,24 @@ async def cleanup_expired_sessions(
try:
# Use optimized bulk DELETE instead of N individual deletes
deleted_count = await session_crud.cleanup_expired_for_user(
db,
user_id=str(current_user.id)
db, user_id=str(current_user.id)
)
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
logger.info(
f"User {current_user.id} cleaned up {deleted_count} expired sessions"
)
return MessageResponse(
success=True,
message=f"Cleaned up {deleted_count} expired sessions"
success=True, message=f"Cleaned up {deleted_count} expired sessions"
)
except Exception as e:
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
logger.error(
f"Error cleaning up sessions for user {current_user.id}: {e!s}",
exc_info=True,
)
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cleanup sessions"
detail="Failed to cleanup sessions",
)

View File

@@ -1,33 +1,30 @@
"""
User management endpoints for CRUD operations.
"""
import logging
from typing import Any, Optional
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query, status, Request
from fastapi import APIRouter, Depends, Query, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user, get_current_superuser
from app.api.dependencies.auth import get_current_superuser, get_current_user
from app.core.database import get_db
from app.core.exceptions import (
NotFoundError,
AuthorizationError,
ErrorCode
)
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.common import (
PaginationParams,
PaginatedResponse,
MessageResponse,
PaginatedResponse,
PaginationParams,
SortParams,
create_pagination_meta
create_pagination_meta,
)
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
from app.services.auth_service import AuthService, AuthenticationError
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
from app.services.auth_service import AuthenticationError, AuthService
logger = logging.getLogger(__name__)
@@ -50,15 +47,15 @@ limiter = Limiter(key_func=get_remote_address)
**Rate Limit**: 60 requests/minute
""",
operation_id="list_users"
operation_id="list_users",
)
async def list_users(
pagination: PaginationParams = Depends(),
sort: SortParams = Depends(),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
is_active: bool | None = Query(None, description="Filter by active status"),
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List all users with pagination, filtering, and sorting.
@@ -80,7 +77,7 @@ async def list_users(
limit=pagination.limit,
sort_by=sort.sort_by,
sort_order=sort.sort_order.value if sort.sort_order else "asc",
filters=filters if filters else None
filters=filters if filters else None,
)
# Create pagination metadata
@@ -88,15 +85,12 @@ async def list_users(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(users)
items_count=len(users),
)
return PaginatedResponse(
data=users,
pagination=pagination_meta
)
return PaginatedResponse(data=users, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing users: {str(e)}", exc_info=True)
logger.error(f"Error listing users: {e!s}", exc_info=True)
raise
@@ -111,11 +105,9 @@ async def list_users(
**Rate Limit**: 60 requests/minute
""",
operation_id="get_current_user_profile"
operation_id="get_current_user_profile",
)
def get_current_user_profile(
current_user: User = Depends(get_current_user)
) -> Any:
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
"""Get current user's profile."""
return current_user
@@ -133,12 +125,12 @@ def get_current_user_profile(
**Rate Limit**: 30 requests/minute
""",
operation_id="update_current_user"
operation_id="update_current_user",
)
async def update_current_user(
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update current user's profile.
@@ -147,17 +139,17 @@ async def update_current_user(
"""
try:
updated_user = await user_crud.update(
db,
db_obj=current_user,
obj_in=user_update
db, db_obj=current_user, obj_in=user_update
)
logger.info(f"User {current_user.id} updated their profile")
return updated_user
except ValueError as e:
logger.error(f"Error updating user {current_user.id}: {str(e)}")
logger.error(f"Error updating user {current_user.id}: {e!s}")
raise
except Exception as e:
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
)
raise
@@ -175,12 +167,12 @@ async def update_current_user(
**Rate Limit**: 60 requests/minute
""",
operation_id="get_user_by_id"
operation_id="get_user_by_id",
)
async def get_user_by_id(
user_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get user by ID.
@@ -194,7 +186,7 @@ async def get_user_by_id(
)
raise AuthorizationError(
message="Not enough permissions to view this user",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Get user
@@ -202,7 +194,7 @@ async def get_user_by_id(
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND,
)
return user
@@ -222,13 +214,13 @@ async def get_user_by_id(
**Rate Limit**: 30 requests/minute
""",
operation_id="update_user"
operation_id="update_user",
)
async def update_user(
user_id: UUID,
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update user by ID.
@@ -245,7 +237,7 @@ async def update_user(
)
raise AuthorizationError(
message="Not enough permissions to update this user",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Get user
@@ -253,7 +245,7 @@ async def update_user(
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND,
)
try:
@@ -261,10 +253,10 @@ async def update_user(
logger.info(f"User {user_id} updated by {current_user.id}")
return updated_user
except ValueError as e:
logger.error(f"Error updating user {user_id}: {str(e)}")
logger.error(f"Error updating user {user_id}: {e!s}")
raise
except Exception as e:
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True)
logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True)
raise
@@ -281,14 +273,14 @@ async def update_user(
**Rate Limit**: 5 requests/minute
""",
operation_id="change_current_user_password"
operation_id="change_current_user_password",
)
@limiter.limit("5/minute")
async def change_current_user_password(
request: Request,
password_change: PasswordChange,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Change current user's password.
@@ -300,23 +292,23 @@ async def change_current_user_password(
db=db,
user_id=current_user.id,
current_password=password_change.current_password,
new_password=password_change.new_password
new_password=password_change.new_password,
)
if success:
logger.info(f"User {current_user.id} changed their password")
return MessageResponse(
success=True,
message="Password changed successfully"
success=True, message="Password changed successfully"
)
except AuthenticationError as e:
logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
logger.warning(
f"Failed password change attempt for user {current_user.id}: {e!s}"
)
raise AuthorizationError(
message=str(e),
error_code=ErrorCode.INVALID_CREDENTIALS
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
)
except Exception as e:
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
logger.error(f"Error changing password for user {current_user.id}: {e!s}")
raise
@@ -335,12 +327,12 @@ async def change_current_user_password(
**Note**: This performs a hard delete. Consider implementing soft deletes for production.
""",
operation_id="delete_user"
operation_id="delete_user",
)
async def delete_user(
user_id: UUID,
current_user: User = Depends(get_current_superuser),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Delete user by ID (superuser only).
@@ -351,7 +343,7 @@ async def delete_user(
if str(user_id) == str(current_user.id):
raise AuthorizationError(
message="Cannot delete your own account",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Get user
@@ -359,7 +351,7 @@ async def delete_user(
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
error_code=ErrorCode.USER_NOT_FOUND,
)
try:
@@ -367,12 +359,11 @@ async def delete_user(
await user_crud.soft_delete(db, id=str(user_id))
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
return MessageResponse(
success=True,
message=f"User {user_id} deleted successfully"
success=True, message=f"User {user_id} deleted successfully"
)
except ValueError as e:
logger.error(f"Error deleting user {user_id}: {str(e)}")
logger.error(f"Error deleting user {user_id}: {e!s}")
raise
except Exception as e:
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True)
logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True)
raise

View File

@@ -1,39 +1,39 @@
import logging
logging.getLogger('passlib').setLevel(logging.ERROR)
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Union
import uuid
logging.getLogger("passlib").setLevel(logging.ERROR)
import asyncio
import uuid
from datetime import UTC, datetime, timedelta
from functools import partial
from typing import Any
from jose import jwt, JWTError
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import ValidationError
from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth
class AuthError(Exception):
"""Base authentication error"""
pass
class TokenExpiredError(AuthError):
"""Token has expired"""
pass
class TokenInvalidError(AuthError):
"""Token is invalid"""
pass
class TokenMissingClaimError(AuthError):
"""Token is missing a required claim"""
pass
def verify_password(plain_password: str, hashed_password: str) -> bool:
@@ -62,8 +62,7 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
partial(pwd_context.verify, plain_password, hashed_password)
None, partial(pwd_context.verify, plain_password, hashed_password)
)
@@ -82,17 +81,13 @@ async def get_password_hash_async(password: str) -> str:
Hashed password string
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
pwd_context.hash,
password
)
return await loop.run_in_executor(None, pwd_context.hash, password)
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
claims: Optional[Dict[str, Any]] = None
subject: str | Any,
expires_delta: timedelta | None = None,
claims: dict[str, Any] | None = None,
) -> str:
"""
Create a JWT access token.
@@ -106,17 +101,19 @@ def create_access_token(
Encoded JWT token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(UTC) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
# Base token data
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(tz=timezone.utc),
"iat": datetime.now(tz=UTC),
"jti": str(uuid.uuid4()),
"type": "access"
"type": "access",
}
# Add custom claims
@@ -125,17 +122,14 @@ def create_access_token(
# Create the JWT
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None
subject: str | Any, expires_delta: timedelta | None = None
) -> str:
"""
Create a JWT refresh token.
@@ -148,28 +142,26 @@ def create_refresh_token(
Encoded JWT refresh token
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
expire = datetime.now(UTC) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"sub": str(subject),
"exp": expire,
"iat": datetime.now(timezone.utc),
"iat": datetime.now(UTC),
"jti": str(uuid.uuid4()),
"type": "refresh"
"type": "refresh",
}
encoded_jwt = jwt.encode(
to_encode,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
"""
Decode and verify a JWT token.
@@ -195,8 +187,8 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
"verify_signature": True,
"verify_exp": True,
"verify_iat": True,
"require": ["exp", "sub", "iat"]
}
"require": ["exp", "sub", "iat"],
},
)
# SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks
@@ -250,4 +242,4 @@ def get_token_data(token: str) -> TokenData:
user_id = payload.sub
is_superuser = payload.is_superuser or False
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)

View File

@@ -1,5 +1,4 @@
import logging
from typing import Optional, List
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
@@ -13,7 +12,7 @@ class Settings(BaseSettings):
# Environment (must be before SECRET_KEY for validation)
ENVIRONMENT: str = Field(
default="development",
description="Environment: development, staging, or production"
description="Environment: development, staging, or production",
)
# Security: Content Security Policy
@@ -21,8 +20,7 @@ class Settings(BaseSettings):
# Set to True for strict CSP (blocks most external resources)
# Set to "relaxed" for modern frontend development
CSP_MODE: str = Field(
default="relaxed",
description="CSP mode: 'strict', 'relaxed', or 'disabled'"
default="relaxed", description="CSP mode: 'strict', 'relaxed', or 'disabled'"
)
# Database configuration
@@ -31,7 +29,7 @@ class Settings(BaseSettings):
POSTGRES_HOST: str = "localhost"
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "app"
DATABASE_URL: Optional[str] = None
DATABASE_URL: str | None = None
db_pool_size: int = 20 # Default connection pool size
db_max_overflow: int = 50 # Maximum overflow connections
db_pool_timeout: int = 30 # Seconds to wait for a connection
@@ -59,38 +57,36 @@ class Settings(BaseSettings):
SECRET_KEY: str = Field(
default="dev_only_insecure_key_change_in_production_32chars_min",
min_length=32,
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'",
)
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard)
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days
# CORS configuration
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
BACKEND_CORS_ORIGINS: list[str] = ["http://localhost:3000"]
# Frontend URL for email links
FRONTEND_URL: str = Field(
default="http://localhost:3000",
description="Frontend application URL for email links"
description="Frontend application URL for email links",
)
# Admin user
FIRST_SUPERUSER_EMAIL: Optional[str] = Field(
default=None,
description="Email for first superuser account"
FIRST_SUPERUSER_EMAIL: str | None = Field(
default=None, description="Email for first superuser account"
)
FIRST_SUPERUSER_PASSWORD: Optional[str] = Field(
default=None,
description="Password for first superuser (min 12 characters)"
FIRST_SUPERUSER_PASSWORD: str | None = Field(
default=None, description="Password for first superuser (min 12 characters)"
)
@field_validator('SECRET_KEY')
@field_validator("SECRET_KEY")
@classmethod
def validate_secret_key(cls, v: str, info) -> str:
"""Validate SECRET_KEY is secure, especially in production."""
# Get environment from values if available
values_data = info.data if info.data else {}
env = values_data.get('ENVIRONMENT', 'development')
env = values_data.get("ENVIRONMENT", "development")
if v.startswith("your_secret_key_here"):
if env == "production":
@@ -106,13 +102,15 @@ class Settings(BaseSettings):
)
if len(v) < 32:
raise ValueError("SECRET_KEY must be at least 32 characters long for security")
raise ValueError(
"SECRET_KEY must be at least 32 characters long for security"
)
return v
@field_validator('FIRST_SUPERUSER_PASSWORD')
@field_validator("FIRST_SUPERUSER_PASSWORD")
@classmethod
def validate_superuser_password(cls, v: Optional[str]) -> Optional[str]:
def validate_superuser_password(cls, v: str | None) -> str | None:
"""Validate superuser password strength."""
if v is None:
return v
@@ -121,7 +119,13 @@ class Settings(BaseSettings):
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
# Check for common weak passwords
weak_passwords = {'admin123', 'Admin123', 'password123', 'Password123', '123456789012'}
weak_passwords = {
"admin123",
"Admin123",
"password123",
"Password123",
"123456789012",
}
if v in weak_passwords:
raise ValueError(
"FIRST_SUPERUSER_PASSWORD is too weak. "
@@ -144,8 +148,8 @@ class Settings(BaseSettings):
"env_file": "../.env",
"env_file_encoding": "utf-8",
"case_sensitive": True,
"extra": "ignore" # Ignore extra fields from .env (e.g., frontend-specific vars)
"extra": "ignore", # Ignore extra fields from .env (e.g., frontend-specific vars)
}
settings = Settings()
settings = Settings()

View File

@@ -5,17 +5,18 @@ Database configuration using SQLAlchemy 2.0 and asyncpg.
This module provides async database connectivity with proper connection pooling
and session management for FastAPI endpoints.
"""
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.asyncio import (
AsyncSession,
AsyncEngine,
create_async_engine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import DeclarativeBase
@@ -27,12 +28,12 @@ logger = logging.getLogger(__name__)
# SQLite compatibility for testing
@compiles(JSONB, 'sqlite')
@compiles(JSONB, "sqlite")
def compile_jsonb_sqlite(type_, compiler, **kw):
return "TEXT"
@compiles(UUID, 'sqlite')
@compiles(UUID, "sqlite")
def compile_uuid_sqlite(type_, compiler, **kw):
return "TEXT"
@@ -40,7 +41,6 @@ def compile_uuid_sqlite(type_, compiler, **kw):
# Declarative base for models (SQLAlchemy 2.0 style)
class Base(DeclarativeBase):
"""Base class for all database models."""
pass
def get_async_database_url(url: str) -> str:
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
logger.debug("Async transaction committed successfully")
except Exception as e:
await session.rollback()
logger.error(f"Async transaction failed, rolling back: {str(e)}")
logger.error(f"Async transaction failed, rolling back: {e!s}")
raise
finally:
await session.close()
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
await db.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Async database health check failed: {str(e)}")
logger.error(f"Async database health check failed: {e!s}")
return False

View File

@@ -1,8 +1,8 @@
"""
Custom exceptions and global exception handlers for the API.
"""
import logging
from typing import Optional, Union
from fastapi import HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
@@ -27,17 +27,13 @@ class APIException(HTTPException):
status_code: int,
error_code: ErrorCode,
message: str,
field: Optional[str] = None,
headers: Optional[dict] = None
field: str | None = None,
headers: dict | None = None,
):
self.error_code = error_code
self.field = field
self.message = message
super().__init__(
status_code=status_code,
detail=message,
headers=headers
)
super().__init__(status_code=status_code, detail=message, headers=headers)
class AuthenticationError(APIException):
@@ -47,14 +43,14 @@ class AuthenticationError(APIException):
self,
message: str = "Authentication failed",
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
field: Optional[str] = None
field: str | None = None,
):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
error_code=error_code,
message=message,
field=field,
headers={"WWW-Authenticate": "Bearer"}
headers={"WWW-Authenticate": "Bearer"},
)
@@ -64,12 +60,12 @@ class AuthorizationError(APIException):
def __init__(
self,
message: str = "Insufficient permissions",
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS,
):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
error_code=error_code,
message=message
message=message,
)
@@ -79,12 +75,12 @@ class NotFoundError(APIException):
def __init__(
self,
message: str = "Resource not found",
error_code: ErrorCode = ErrorCode.NOT_FOUND
error_code: ErrorCode = ErrorCode.NOT_FOUND,
):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
error_code=error_code,
message=message
message=message,
)
@@ -95,13 +91,13 @@ class DuplicateError(APIException):
self,
message: str = "Resource already exists",
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
field: Optional[str] = None
field: str | None = None,
):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
error_code=error_code,
message=message,
field=field
field=field,
)
@@ -112,13 +108,13 @@ class ValidationException(APIException):
self,
message: str = "Validation error",
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
field: Optional[str] = None
field: str | None = None,
):
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
error_code=error_code,
message=message,
field=field
field=field,
)
@@ -128,12 +124,12 @@ class DatabaseError(APIException):
def __init__(
self,
message: str = "Database operation failed",
error_code: ErrorCode = ErrorCode.DATABASE_ERROR
error_code: ErrorCode = ErrorCode.DATABASE_ERROR,
):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
error_code=error_code,
message=message
message=message,
)
@@ -152,23 +148,18 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
)
error_response = ErrorResponse(
errors=[ErrorDetail(
code=exc.error_code,
message=exc.message,
field=exc.field
)]
errors=[ErrorDetail(code=exc.error_code, message=exc.message, field=exc.field)]
)
return JSONResponse(
status_code=exc.status_code,
content=error_response.model_dump(),
headers=exc.headers
headers=exc.headers,
)
async def validation_exception_handler(
request: Request,
exc: Union[RequestValidationError, ValidationError]
request: Request, exc: RequestValidationError | ValidationError
) -> JSONResponse:
"""
Handler for Pydantic validation errors.
@@ -189,22 +180,19 @@ async def validation_exception_handler(
# Skip 'body' or 'query' prefix in location
field = ".".join(str(x) for x in error["loc"][1:])
errors.append(ErrorDetail(
code=ErrorCode.VALIDATION_ERROR,
message=error["msg"],
field=field
))
errors.append(
ErrorDetail(
code=ErrorCode.VALIDATION_ERROR, message=error["msg"], field=field
)
)
logger.warning(
f"Validation error: {len(errors)} errors "
f"(path: {request.url.path})"
)
logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})")
error_response = ErrorResponse(errors=errors)
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=error_response.model_dump()
content=error_response.model_dump(),
)
@@ -226,26 +214,21 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
}
error_code = status_code_to_error_code.get(
exc.status_code,
ErrorCode.INTERNAL_ERROR
exc.status_code, ErrorCode.INTERNAL_ERROR
)
logger.warning(
f"HTTP exception: {exc.status_code} - {exc.detail} "
f"(path: {request.url.path})"
f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})"
)
error_response = ErrorResponse(
errors=[ErrorDetail(
code=error_code,
message=str(exc.detail)
)]
errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
)
return JSONResponse(
status_code=exc.status_code,
content=error_response.model_dump(),
headers=exc.headers
headers=exc.headers,
)
@@ -257,26 +240,24 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
leaking sensitive information in production.
"""
logger.error(
f"Unhandled exception: {type(exc).__name__} - {str(exc)} "
f"Unhandled exception: {type(exc).__name__} - {exc!s} "
f"(path: {request.url.path})",
exc_info=True
exc_info=True,
)
# In production, don't expose internal error details
from app.core.config import settings
if settings.ENVIRONMENT == "production":
message = "An internal error occurred. Please try again later."
else:
message = f"{type(exc).__name__}: {str(exc)}"
message = f"{type(exc).__name__}: {exc!s}"
error_response = ErrorResponse(
errors=[ErrorDetail(
code=ErrorCode.INTERNAL_ERROR,
message=message
)]
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=error_response.model_dump()
content=error_response.model_dump(),
)

View File

@@ -3,4 +3,4 @@ from .organization import organization
from .session import session as session_crud
from .user import user
__all__ = ["user", "session_crud", "organization"]
__all__ = ["organization", "session_crud", "user"]

View File

@@ -4,14 +4,16 @@ Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
import logging
import uuid
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from datetime import UTC
from typing import Any, TypeVar
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
@@ -24,10 +26,14 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
class CRUDBase[
ModelType: Base,
CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel,
]:
"""Async CRUD operations for a model."""
def __init__(self, model: Type[ModelType]):
def __init__(self, model: type[ModelType]):
"""
CRUD object with default async methods to Create, Read, Update, Delete.
@@ -37,11 +43,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
self.model = model
async def get(
self,
db: AsyncSession,
id: str,
options: Optional[List[Load]] = None
) -> Optional[ModelType]:
self, db: AsyncSession, id: str, options: list[Load] | None = None
) -> ModelType | None:
"""
Get a single record by ID with UUID validation and optional eager loading.
@@ -66,7 +69,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
logger.warning(f"Invalid UUID format: {id} - {e!s}")
return None
try:
@@ -80,7 +83,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
raise
async def get_multi(
@@ -89,8 +92,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
*,
skip: int = 0,
limit: int = 100,
options: Optional[List[Load]] = None
) -> List[ModelType]:
options: list[Load] | None = None,
) -> list[ModelType]:
"""
Get multiple records with pagination validation and optional eager loading.
@@ -122,10 +125,14 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
logger.error(
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
)
raise
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: # pragma: no cover
async def create(
self, db: AsyncSession, *, obj_in: CreateSchemaType
) -> ModelType: # pragma: no cover
"""Create a new record with error handling.
NOTE: This method is defensive code that's never called in practice.
@@ -142,19 +149,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: # pragma: no cover
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {e!s}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
)
raise
async def update(
@@ -162,7 +175,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
obj_in: UpdateSchemaType | dict[str, Any],
) -> ModelType:
"""Update a record with error handling."""
try:
@@ -182,22 +195,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {e!s}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
)
raise
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try:
@@ -206,7 +225,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
return None
try:
@@ -216,7 +235,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
logger.warning(
f"{self.model.__name__} with id {id} not found for deletion"
)
return None
await db.delete(obj)
@@ -224,12 +245,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
raise ValueError(
f"Cannot delete {self.model.__name__}: referenced by other records"
)
except Exception as e:
await db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
logger.error(
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise
async def get_multi_with_total(
@@ -238,10 +264,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_by: str | None = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]:
filters: dict[str, Any] | None = None,
) -> tuple[list[ModelType], int]:
"""
Get multiple records with total count, filtering, and sorting.
@@ -269,7 +295,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
if hasattr(self.model, "deleted_at"):
query = query.where(self.model.deleted_at.is_(None))
# Apply filters
@@ -298,7 +324,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
logger.error(
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
)
raise
async def count(self, db: AsyncSession) -> int:
@@ -307,7 +335,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
logger.error(f"Error counting {self.model.__name__} records: {e!s}")
raise
async def exists(self, db: AsyncSession, id: str) -> bool:
@@ -315,13 +343,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = await self.get(db, id=id)
return obj is not None
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
from datetime import datetime, timezone
from datetime import datetime
# Validate UUID format and convert to UUID object if string
try:
@@ -330,7 +358,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
return None
try:
@@ -340,26 +368,33 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj = result.scalar_one_or_none()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
logger.warning(
f"{self.model.__name__} with id {id} not found for soft deletion"
)
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
if not hasattr(self.model, "deleted_at"):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
raise ValueError(
f"{self.model.__name__} does not have a deleted_at column"
)
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
obj.deleted_at = datetime.now(UTC)
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
logger.error(
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
@@ -372,25 +407,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
if hasattr(self.model, "deleted_at"):
result = await db.execute(
select(self.model).where(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
)
)
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
raise ValueError(
f"{self.model.__name__} does not have a deleted_at column"
)
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
logger.warning(
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
)
return None
# Clear deleted_at timestamp
@@ -401,5 +439,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
return obj
except Exception as e:
await db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
logger.error(
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
exc_info=True,
)
raise

View File

@@ -1,17 +1,18 @@
# app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
import logging
from typing import Optional, List, Dict, Any
from typing import Any
from uuid import UUID
from sqlalchemy import func, or_, and_, select, case
from sqlalchemy import and_, case, func, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization, OrganizationRole
from app.models.user_organization import OrganizationRole, UserOrganization
from app.schemas.organizations import (
OrganizationCreate,
OrganizationUpdate,
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
"""Async CRUD operations for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
"""Get organization by slug."""
try:
result = await db.execute(
@@ -31,10 +32,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
logger.error(f"Error getting organization by slug {slug}: {e!s}")
raise
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
async def create(
self, db: AsyncSession, *, obj_in: OrganizationCreate
) -> Organization:
"""Create a new organization with error handling."""
try:
db_obj = Organization(
@@ -42,7 +45,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
slug=obj_in.slug,
description=obj_in.description,
is_active=obj_in.is_active,
settings=obj_in.settings or {}
settings=obj_in.settings or {},
)
db.add(db_obj)
await db.commit()
@@ -50,15 +53,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
raise ValueError(
f"Organization with slug '{obj_in.slug}' already exists"
)
logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
logger.error(
f"Unexpected error creating organization: {e!s}", exc_info=True
)
raise
async def get_multi_with_filters(
@@ -67,11 +74,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None,
is_active: bool | None = None,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc"
) -> tuple[List[Organization], int]:
sort_order: str = "desc",
) -> tuple[list[Organization], int]:
"""
Get multiple organizations with filtering, searching, and sorting.
@@ -89,7 +96,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
Organization.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
@@ -112,7 +119,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return organizations, total
except Exception as e:
logger.error(f"Error getting organizations with filters: {str(e)}")
logger.error(f"Error getting organizations with filters: {e!s}")
raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
@@ -122,13 +129,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(func.count(UserOrganization.user_id)).where(
and_(
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
UserOrganization.is_active,
)
)
)
return result.scalar_one() or 0
except Exception as e:
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
logger.error(
f"Error getting member count for organization {organization_id}: {e!s}"
)
raise
async def get_multi_with_member_counts(
@@ -137,9 +146,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
*,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> tuple[List[Dict[str, Any]], int]:
is_active: bool | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
@@ -156,13 +165,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
func.count(
func.distinct(
case(
(UserOrganization.is_active == True, UserOrganization.user_id),
else_=None
(
UserOrganization.is_active,
UserOrganization.user_id,
),
else_=None,
)
)
).label('member_count')
).label("member_count"),
)
.outerjoin(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.id)
)
@@ -174,7 +189,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
Organization.slug.ilike(f"%{search}%"),
Organization.description.ilike(f"%{search}%")
Organization.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
@@ -189,24 +204,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
query = (
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
)
result = await db.execute(query)
rows = result.all()
# Convert to list of dicts
orgs_with_counts = [
{
'organization': org,
'member_count': member_count
}
{"organization": org, "member_count": member_count}
for org, member_count in rows
]
return orgs_with_counts, total
except Exception as e:
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
logger.error(
f"Error getting organizations with member counts: {e!s}", exc_info=True
)
raise
async def add_user(
@@ -216,7 +232,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER,
custom_permissions: Optional[str] = None
custom_permissions: str | None = None,
) -> UserOrganization:
"""Add a user to an organization with a specific role."""
try:
@@ -225,7 +241,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
UserOrganization.organization_id == organization_id,
)
)
)
@@ -249,7 +265,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id=organization_id,
role=role,
is_active=True,
custom_permissions=custom_permissions
custom_permissions=custom_permissions,
)
db.add(user_org)
await db.commit()
@@ -257,19 +273,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org
except IntegrityError as e:
await db.rollback()
logger.error(f"Integrity error adding user to organization: {str(e)}")
logger.error(f"Integrity error adding user to organization: {e!s}")
raise ValueError("Failed to add user to organization")
except Exception as e:
await db.rollback()
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
raise
async def remove_user(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID
self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
) -> bool:
"""Remove a user from an organization (soft delete)."""
try:
@@ -277,7 +289,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
UserOrganization.organization_id == organization_id,
)
)
)
@@ -291,7 +303,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return True
except Exception as e:
await db.rollback()
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
raise
async def update_user_role(
@@ -301,15 +313,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
user_id: UUID,
role: OrganizationRole,
custom_permissions: Optional[str] = None
) -> Optional[UserOrganization]:
custom_permissions: str | None = None,
) -> UserOrganization | None:
"""Update a user's role in an organization."""
try:
result = await db.execute(
select(UserOrganization).where(
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id
UserOrganization.organization_id == organization_id,
)
)
)
@@ -326,7 +338,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org
except Exception as e:
await db.rollback()
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
logger.error(f"Error updating user role: {e!s}", exc_info=True)
raise
async def get_organization_members(
@@ -336,8 +348,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool = True
) -> tuple[List[Dict[str, Any]], int]:
is_active: bool = True,
) -> tuple[list[dict[str, Any]], int]:
"""
Get members of an organization with user details.
@@ -359,46 +371,55 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
.where(UserOrganization.is_active == is_active if is_active is not None else True)
.where(
UserOrganization.is_active == is_active
if is_active is not None
else True
)
.alias()
)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply ordering and pagination
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
query = (
query.order_by(UserOrganization.created_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query)
results = result.all()
members = []
for user_org, user in results:
members.append({
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at
})
members.append(
{
"user_id": user.id,
"email": user.email,
"first_name": user.first_name,
"last_name": user.last_name,
"role": user_org.role,
"is_active": user_org.is_active,
"joined_at": user_org.created_at,
}
)
return members, total
except Exception as e:
logger.error(f"Error getting organization members: {str(e)}")
logger.error(f"Error getting organization members: {e!s}")
raise
async def get_user_organizations(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Organization]:
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
) -> list[Organization]:
"""Get all organizations a user belongs to."""
try:
query = (
select(Organization)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.where(UserOrganization.user_id == user_id)
)
@@ -408,16 +429,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting user organizations: {str(e)}")
logger.error(f"Error getting user organizations: {e!s}")
raise
async def get_user_organizations_with_details(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool = True
) -> List[Dict[str, Any]]:
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
) -> list[dict[str, Any]]:
"""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
@@ -430,9 +447,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
member_count_subq = (
select(
UserOrganization.organization_id,
func.count(UserOrganization.user_id).label('member_count')
func.count(UserOrganization.user_id).label("member_count"),
)
.where(UserOrganization.is_active == True)
.where(UserOrganization.is_active)
.group_by(UserOrganization.organization_id)
.subquery()
)
@@ -442,10 +459,18 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
select(
Organization,
UserOrganization.role,
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
func.coalesce(member_count_subq.c.member_count, 0).label(
"member_count"
),
)
.join(
UserOrganization,
Organization.id == UserOrganization.organization_id,
)
.outerjoin(
member_count_subq,
Organization.id == member_count_subq.c.organization_id,
)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
.where(UserOrganization.user_id == user_id)
)
@@ -456,25 +481,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
rows = result.all()
return [
{
'organization': org,
'role': role,
'member_count': member_count
}
{"organization": org, "role": role, "member_count": member_count}
for org, role, member_count in rows
]
except Exception as e:
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
logger.error(
f"Error getting user organizations with details: {e!s}", exc_info=True
)
raise
async def get_user_role_in_org(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
) -> Optional[OrganizationRole]:
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> OrganizationRole | None:
"""Get a user's role in a specific organization."""
try:
result = await db.execute(
@@ -482,7 +501,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
and_(
UserOrganization.user_id == user_id,
UserOrganization.organization_id == organization_id,
UserOrganization.is_active == True
UserOrganization.is_active,
)
)
)
@@ -490,29 +509,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org.role if user_org else None
except Exception as e:
logger.error(f"Error getting user role in org: {str(e)}")
logger.error(f"Error getting user role in org: {e!s}")
raise
async def is_user_org_owner(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> bool:
"""Check if a user is an owner of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role == OrganizationRole.OWNER
async def is_user_org_admin(
self,
db: AsyncSession,
*,
user_id: UUID,
organization_id: UUID
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> bool:
"""Check if a user is an owner or admin of an organization."""
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
role = await self.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]

View File

@@ -1,13 +1,13 @@
"""
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
import logging
import uuid
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from datetime import UTC, datetime, timedelta
from uuid import UUID
from sqlalchemy import and_, select, update, delete, func
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""
Get session by refresh token JTI.
@@ -38,10 +38,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
logger.error(f"Error getting session by JTI {jti}: {e!s}")
raise
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
async def get_active_by_jti(
self, db: AsyncSession, *, jti: str
) -> UserSession | None:
"""
Get active session by refresh token JTI.
@@ -57,13 +59,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
UserSession.is_active,
)
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
logger.error(f"Error getting active session by JTI {jti}: {e!s}")
raise
async def get_user_sessions(
@@ -72,8 +74,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
*,
user_id: str,
active_only: bool = True,
with_user: bool = False
) -> List[UserSession]:
with_user: bool = False,
) -> list[UserSession]:
"""
Get all sessions for a user with optional eager loading.
@@ -97,20 +99,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
query = query.where(UserSession.is_active)
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
logger.error(f"Error getting sessions for user {user_id}: {e!s}")
raise
async def create_session(
self,
db: AsyncSession,
*,
obj_in: SessionCreate
self, db: AsyncSession, *, obj_in: SessionCreate
) -> UserSession:
"""
Create a new user session.
@@ -151,10 +150,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return db_obj
except Exception as e:
await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
logger.error(f"Error creating session: {e!s}", exc_info=True)
raise ValueError(f"Failed to create session: {e!s}")
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
async def deactivate(
self, db: AsyncSession, *, session_id: str
) -> UserSession | None:
"""
Deactivate a session (logout from device).
@@ -184,14 +185,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return session
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
logger.error(f"Error deactivating session {session_id}: {e!s}")
raise
async def deactivate_all_user_sessions(
self,
db: AsyncSession,
*,
user_id: str
self, db: AsyncSession, *, user_id: str
) -> int:
"""
Deactivate all active sessions for a user (logout from all devices).
@@ -209,12 +207,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
stmt = (
update(UserSession)
.where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
)
.where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
.values(is_active=False)
)
@@ -228,14 +221,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return count
except Exception as e:
await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
raise
async def update_last_used(
self,
db: AsyncSession,
*,
session: UserSession
self, db: AsyncSession, *, session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
@@ -248,14 +238,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Updated UserSession
"""
try:
session.last_used_at = datetime.now(timezone.utc)
session.last_used_at = datetime.now(UTC)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
logger.error(f"Error updating last_used for session {session.id}: {e!s}")
raise
async def update_refresh_token(
@@ -264,7 +254,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime
new_expires_at: datetime,
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
@@ -283,14 +273,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
session.last_used_at = datetime.now(UTC)
db.add(session)
await db.commit()
await db.refresh(session)
return session
except Exception as e:
await db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
logger.error(
f"Error updating refresh token for session {session.id}: {e!s}"
)
raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
@@ -311,15 +303,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
Number of sessions deleted
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
now = datetime.now(timezone.utc)
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False,
not UserSession.is_active,
UserSession.expires_at < now,
UserSession.created_at < cutoff_date
UserSession.created_at < cutoff_date,
)
)
@@ -334,15 +326,10 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return count
except Exception as e:
await db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
logger.error(f"Error cleaning up expired sessions: {e!s}")
raise
async def cleanup_expired_for_user(
self,
db: AsyncSession,
*,
user_id: str
) -> int:
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
"""
Clean up expired and inactive sessions for a specific user.
@@ -363,14 +350,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}")
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
UserSession.is_active == False,
UserSession.expires_at < now
not UserSession.is_active,
UserSession.expires_at < now,
)
)
@@ -388,7 +375,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
except Exception as e:
await db.rollback()
logger.error(
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
f"Error cleaning up expired sessions for user {user_id}: {e!s}"
)
raise
@@ -409,15 +396,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
result = await db.execute(
select(func.count(UserSession.id)).where(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
and_(UserSession.user_id == user_uuid, UserSession.is_active)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
logger.error(f"Error counting sessions for user {user_id}: {e!s}")
raise
async def get_all_sessions(
@@ -427,8 +411,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
skip: int = 0,
limit: int = 100,
active_only: bool = True,
with_user: bool = True
) -> tuple[List[UserSession], int]:
with_user: bool = True,
) -> tuple[list[UserSession], int]:
"""
Get all sessions across all users with pagination (admin only).
@@ -451,18 +435,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active == True)
query = query.where(UserSession.is_active)
# Get total count
count_query = select(func.count(UserSession.id))
if active_only:
count_query = count_query.where(UserSession.is_active == True)
count_query = count_query.where(UserSession.is_active)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination and ordering
query = query.order_by(UserSession.last_used_at.desc()).offset(skip).limit(limit)
query = (
query.order_by(UserSession.last_used_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(query)
sessions = list(result.scalars().all())
@@ -470,7 +458,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return sessions, total
except Exception as e:
logger.error(f"Error getting all sessions: {str(e)}", exc_info=True)
logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
raise

View File

@@ -1,8 +1,9 @@
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import datetime, timezone
from typing import Optional, Union, Dict, Any, List, Tuple
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import or_, select, update
@@ -20,15 +21,13 @@ logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
"""Get user by email address."""
try:
result = await db.execute(
select(User).where(User.email == email)
)
result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting user by email {email}: {str(e)}")
logger.error(f"Error getting user by email {email}: {e!s}")
raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
@@ -42,9 +41,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
password_hash=password_hash,
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
phone_number=obj_in.phone_number
if hasattr(obj_in, "phone_number")
else None,
is_superuser=obj_in.is_superuser
if hasattr(obj_in, "is_superuser")
else False,
preferences={},
)
db.add(db_obj)
await db.commit()
@@ -52,7 +55,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise ValueError(f"User with email {obj_in.email} already exists")
@@ -60,15 +63,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
raise
async def update(
self,
db: AsyncSession,
*,
db_obj: User,
obj_in: Union[UserUpdate, Dict[str, Any]]
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
) -> User:
"""Update user with async password hashing if password is updated."""
if isinstance(obj_in, dict):
@@ -79,7 +78,9 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
update_data["password_hash"] = await get_password_hash_async(
update_data["password"]
)
del update_data["password"]
return await super().update(db, db_obj=db_obj, obj_in=update_data)
@@ -90,11 +91,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_by: str | None = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None,
search: Optional[str] = None
) -> Tuple[List[User], int]:
filters: dict[str, Any] | None = None,
search: str | None = None,
) -> tuple[list[User], int]:
"""
Get multiple users with total count, filtering, sorting, and search.
@@ -136,12 +137,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
search_filter = or_(
User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%")
User.last_name.ilike(f"%{search}%"),
)
query = query.where(search_filter)
# Get total count
from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
@@ -162,15 +164,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return users, total
except Exception as e:
logger.error(f"Error retrieving paginated users: {str(e)}")
logger.error(f"Error retrieving paginated users: {e!s}")
raise
async def bulk_update_status(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
is_active: bool
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
@@ -192,7 +190,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
.values(is_active=is_active, updated_at=datetime.now(UTC))
)
result = await db.execute(stmt)
@@ -204,15 +202,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
except Exception as e:
await db.rollback()
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
raise
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: List[UUID],
exclude_user_id: Optional[UUID] = None
user_ids: list[UUID],
exclude_user_id: UUID | None = None,
) -> int:
"""
Bulk soft delete multiple users.
@@ -239,11 +237,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
.where(
User.deleted_at.is_(None)
) # Don't re-delete already deleted users
.values(
deleted_at=datetime.now(timezone.utc),
deleted_at=datetime.now(UTC),
is_active=False,
updated_at=datetime.now(timezone.utc)
updated_at=datetime.now(UTC),
)
)
@@ -256,7 +256,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
except Exception as e:
await db.rollback()
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
raise
def is_active(self, user: User) -> bool:

View File

@@ -4,9 +4,9 @@ Async database initialization script.
Creates the first superuser if configured and doesn't already exist.
"""
import asyncio
import logging
from typing import Optional
from app.core.config import settings
from app.core.database import SessionLocal, engine
@@ -17,7 +17,7 @@ from app.schemas.users import UserCreate
logger = logging.getLogger(__name__)
async def init_db() -> Optional[User]:
async def init_db() -> User | None:
"""
Initialize database with first superuser if settings are configured and user doesn't exist.
@@ -49,7 +49,7 @@ async def init_db() -> Optional[User]:
password=superuser_password,
first_name="Admin",
last_name="User",
is_superuser=True
is_superuser=True,
)
user = await user_crud.create(session, obj_in=user_in)
@@ -70,13 +70,13 @@ async def main():
# Configure logging to show info logs
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
try:
user = await init_db()
if user:
print(f"✓ Database initialized successfully")
print("✓ Database initialized successfully")
print(f"✓ Superuser: {user.email}")
else:
print("✗ Failed to initialize database")

View File

@@ -2,10 +2,10 @@ import logging
import os
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Dict, Any
from typing import Any
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import FastAPI, status, Request, HTTPException
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
@@ -19,9 +19,9 @@ from app.core.database import check_database_health
from app.core.exceptions import (
APIException,
api_exception_handler,
validation_exception_handler,
http_exception_handler,
unhandled_exception_handler
unhandled_exception_handler,
validation_exception_handler,
)
scheduler = AsyncIOScheduler()
@@ -52,11 +52,11 @@ async def lifespan(app: FastAPI):
# Runs daily at 2:00 AM server time
scheduler.add_job(
cleanup_expired_sessions,
'cron',
"cron",
hour=2,
minute=0,
id='cleanup_expired_sessions',
replace_existing=True
id="cleanup_expired_sessions",
replace_existing=True,
)
scheduler.start()
@@ -73,12 +73,12 @@ async def lifespan(app: FastAPI):
logger.info("Scheduled jobs stopped")
logger.info(f"Starting app!!!")
logger.info("Starting app!!!")
app = FastAPI(
title=settings.PROJECT_NAME,
version=settings.VERSION,
openapi_url=f"{settings.API_V1_STR}/openapi.json",
lifespan=lifespan
lifespan=lifespan,
)
# Add rate limiter state to app
@@ -96,7 +96,14 @@ app.add_middleware(
CORSMiddleware,
allow_origins=settings.BACKEND_CORS_ORIGINS,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only
allow_methods=[
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"OPTIONS",
], # Explicit methods only
allow_headers=[
"Content-Type",
"Authorization",
@@ -129,12 +136,14 @@ async def limit_request_size(request: Request, call_next):
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content={
"success": False,
"errors": [{
"code": "REQUEST_TOO_LARGE",
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
"field": None
}]
}
"errors": [
{
"code": "REQUEST_TOO_LARGE",
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
"field": None,
}
],
},
)
response = await call_next(request)
@@ -165,15 +174,19 @@ async def add_security_headers(request: Request, call_next):
# Enforce HTTPS in production
if settings.ENVIRONMENT == "production":
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains"
)
# Content Security Policy
csp_mode = settings.CSP_MODE.lower()
# Special handling for API docs
is_docs = request.url.path in ["/docs", "/redoc"] or \
request.url.path.startswith("/docs/") or \
request.url.path.startswith("/redoc/")
is_docs = (
request.url.path in ["/docs", "/redoc"]
or request.url.path.startswith("/docs/")
or request.url.path.startswith("/redoc/")
)
if csp_mode == "disabled":
# No CSP (only for local development/debugging)
@@ -264,7 +277,7 @@ async def root():
description="Check the health status of the API and its dependencies",
response_description="Health status information",
tags=["Health"],
operation_id="health_check"
operation_id="health_check",
)
async def health_check() -> JSONResponse:
"""
@@ -278,12 +291,12 @@ async def health_check() -> JSONResponse:
- environment: Current environment (development, staging, production)
- database: Database connectivity status
"""
health_status: Dict[str, Any] = {
health_status: dict[str, Any] = {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat() + "Z",
"version": settings.VERSION,
"environment": settings.ENVIRONMENT,
"checks": {}
"checks": {},
}
response_status = status.HTTP_200_OK
@@ -294,7 +307,7 @@ async def health_check() -> JSONResponse:
if db_healthy:
health_status["checks"]["database"] = {
"status": "healthy",
"message": "Database connection successful"
"message": "Database connection successful",
}
else:
raise Exception("Database health check returned unhealthy status")
@@ -302,15 +315,12 @@ async def health_check() -> JSONResponse:
health_status["status"] = "unhealthy"
health_status["checks"]["database"] = {
"status": "unhealthy",
"message": f"Database connection failed: {str(e)}"
"message": f"Database connection failed: {e!s}",
}
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
logger.error(f"Health check failed - database error: {e}")
return JSONResponse(
status_code=response_status,
content=health_status
)
return JSONResponse(status_code=response_status, content=health_status)
app.include_router(api_router, prefix=settings.API_V1_STR)

View File

@@ -2,17 +2,25 @@
Models package initialization.
Imports all models to ensure they're registered with SQLAlchemy.
"""
# First import Base to avoid circular imports
from app.core.database import Base
from .base import TimestampMixin, UUIDMixin
from .organization import Organization
# Import models
from .user import User
from .user_organization import UserOrganization, OrganizationRole
from .user_organization import OrganizationRole, UserOrganization
from .user_session import UserSession
__all__ = [
'Base', 'TimestampMixin', 'UUIDMixin',
'User', 'UserSession',
'Organization', 'UserOrganization', 'OrganizationRole',
]
"Base",
"Organization",
"OrganizationRole",
"TimestampMixin",
"UUIDMixin",
"User",
"UserOrganization",
"UserSession",
]

View File

@@ -1,20 +1,27 @@
import uuid
from datetime import datetime, timezone
from datetime import UTC, datetime
from sqlalchemy import Column, DateTime
from sqlalchemy.dialects.postgresql import UUID
# noinspection PyUnresolvedReferences
from app.core.database import Base
class TimestampMixin:
"""Mixin to add created_at and updated_at timestamps to models"""
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
created_at = Column(
DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
nullable=False,
)
class UUIDMixin:
"""Mixin to add UUID primary keys to models"""
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

View File

@@ -1,5 +1,5 @@
# app/models/organization.py
from sqlalchemy import Column, String, Boolean, Text, Index
from sqlalchemy import Boolean, Column, Index, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
@@ -11,7 +11,8 @@ class Organization(Base, UUIDMixin, TimestampMixin):
Organization model for multi-tenant support.
Users can belong to multiple organizations with different roles.
"""
__tablename__ = 'organizations'
__tablename__ = "organizations"
name = Column(String(255), nullable=False, index=True)
slug = Column(String(255), unique=True, nullable=False, index=True)
@@ -20,11 +21,13 @@ class Organization(Base, UUIDMixin, TimestampMixin):
settings = Column(JSONB, default={})
# Relationships
user_organizations = relationship("UserOrganization", back_populates="organization", cascade="all, delete-orphan")
user_organizations = relationship(
"UserOrganization", back_populates="organization", cascade="all, delete-orphan"
)
__table_args__ = (
Index('ix_organizations_name_active', 'name', 'is_active'),
Index('ix_organizations_slug_active', 'slug', 'is_active'),
Index("ix_organizations_name_active", "name", "is_active"),
Index("ix_organizations_slug_active", "slug", "is_active"),
)
def __repr__(self):

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, String, Boolean, DateTime
from sqlalchemy import Boolean, Column, DateTime, String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
@@ -6,7 +6,7 @@ from .base import Base, TimestampMixin, UUIDMixin
class User(Base, UUIDMixin, TimestampMixin):
__tablename__ = 'users'
__tablename__ = "users"
email = Column(String(255), unique=True, nullable=False, index=True)
password_hash = Column(String(255), nullable=False)
@@ -19,7 +19,9 @@ class User(Base, UUIDMixin, TimestampMixin):
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
# Relationships
user_organizations = relationship("UserOrganization", back_populates="user", cascade="all, delete-orphan")
user_organizations = relationship(
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<User {self.email}>"
return f"<User {self.email}>"

View File

@@ -1,7 +1,7 @@
# app/models/user_organization.py
from enum import Enum as PyEnum
from sqlalchemy import Column, ForeignKey, Boolean, String, Index, Enum
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID as PGUUID
from sqlalchemy.orm import relationship
@@ -14,6 +14,7 @@ class OrganizationRole(str, PyEnum):
These provide a baseline role system that can be optionally used.
Projects can extend this or implement their own permission system.
"""
OWNER = "owner" # Full control over organization
ADMIN = "admin" # Can manage users and settings
MEMBER = "member" # Regular member with standard access
@@ -25,25 +26,41 @@ class UserOrganization(Base, TimestampMixin):
Junction table for many-to-many relationship between Users and Organizations.
Includes role information for flexible RBAC.
"""
__tablename__ = 'user_organizations'
user_id = Column(PGUUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), primary_key=True)
organization_id = Column(PGUUID(as_uuid=True), ForeignKey('organizations.id', ondelete='CASCADE'), primary_key=True)
__tablename__ = "user_organizations"
role = Column(Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False, index=True)
user_id = Column(
PGUUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
primary_key=True,
)
organization_id = Column(
PGUUID(as_uuid=True),
ForeignKey("organizations.id", ondelete="CASCADE"),
primary_key=True,
)
role = Column(
Enum(OrganizationRole),
default=OrganizationRole.MEMBER,
nullable=False,
index=True,
)
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Optional: Custom permissions override for specific users
custom_permissions = Column(String(500), nullable=True) # JSON array of permission strings
custom_permissions = Column(
String(500), nullable=True
) # JSON array of permission strings
# Relationships
user = relationship("User", back_populates="user_organizations")
organization = relationship("Organization", back_populates="user_organizations")
__table_args__ = (
Index('ix_user_org_user_active', 'user_id', 'is_active'),
Index('ix_user_org_org_active', 'organization_id', 'is_active'),
Index('ix_user_org_role', 'role'),
Index("ix_user_org_user_active", "user_id", "is_active"),
Index("ix_user_org_org_active", "organization_id", "is_active"),
Index("ix_user_org_role", "role"),
)
def __repr__(self):

View File

@@ -6,7 +6,10 @@ This allows users to:
- Logout from specific devices
- Manage their active sessions
"""
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Index
from datetime import UTC
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -20,19 +23,27 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
Each time a user logs in from a device, a new session is created.
Sessions are identified by the refresh token JTI (JWT ID).
"""
__tablename__ = 'user_sessions'
__tablename__ = "user_sessions"
# Foreign key to user
user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False, index=True)
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Refresh token identifier (JWT ID from the refresh token)
refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True)
# Device information
device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook"
device_id = Column(String(255), nullable=True) # Persistent device identifier (from client)
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
user_agent = Column(String(500), nullable=True) # Browser/app user agent
device_id = Column(
String(255), nullable=True
) # Persistent device identifier (from client)
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
user_agent = Column(String(500), nullable=True) # Browser/app user agent
# Session timing
last_used_at = Column(DateTime(timezone=True), nullable=False)
@@ -50,8 +61,8 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
# Composite indexes for performance (defined in migration)
__table_args__ = (
Index('ix_user_sessions_user_active', 'user_id', 'is_active'),
Index('ix_user_sessions_jti_active', 'refresh_token_jti', 'is_active'),
Index("ix_user_sessions_user_active", "user_id", "is_active"),
Index("ix_user_sessions_jti_active", "refresh_token_jti", "is_active"),
)
def __repr__(self):
@@ -60,21 +71,24 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
@property
def is_expired(self) -> bool:
"""Check if session has expired."""
from datetime import datetime, timezone
return self.expires_at < datetime.now(timezone.utc)
from datetime import datetime
return self.expires_at < datetime.now(UTC)
def to_dict(self):
"""Convert session to dictionary for serialization."""
return {
'id': str(self.id),
'user_id': str(self.user_id),
'device_name': self.device_name,
'device_id': self.device_id,
'ip_address': self.ip_address,
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
'expires_at': self.expires_at.isoformat() if self.expires_at else None,
'is_active': self.is_active,
'location_city': self.location_city,
'location_country': self.location_country,
'created_at': self.created_at.isoformat() if self.created_at else None,
"id": str(self.id),
"user_id": str(self.user_id),
"device_name": self.device_name,
"device_id": self.device_id,
"ip_address": self.ip_address,
"last_used_at": self.last_used_at.isoformat()
if self.last_used_at
else None,
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"is_active": self.is_active,
"location_city": self.location_city,
"location_country": self.location_country,
"created_at": self.created_at.isoformat() if self.created_at else None,
}

View File

@@ -1,18 +1,20 @@
"""
Common schemas used across the API for pagination, responses, filtering, and sorting.
"""
from enum import Enum
from math import ceil
from typing import Generic, TypeVar, List, Optional
from typing import TypeVar
from uuid import UUID
from pydantic import BaseModel, Field
T = TypeVar('T')
T = TypeVar("T")
class SortOrder(str, Enum):
"""Sort order options."""
ASC = "asc"
DESC = "desc"
@@ -20,16 +22,9 @@ class SortOrder(str, Enum):
class PaginationParams(BaseModel):
"""Parameters for pagination."""
page: int = Field(
default=1,
ge=1,
description="Page number (1-indexed)"
)
page: int = Field(default=1, ge=1, description="Page number (1-indexed)")
limit: int = Field(
default=20,
ge=1,
le=100,
description="Number of items per page (max 100)"
default=20, ge=1, le=100, description="Number of items per page (max 100)"
)
@property
@@ -42,34 +37,20 @@ class PaginationParams(BaseModel):
"""Alias for offset (compatibility with existing code)."""
return self.offset
model_config = {
"json_schema_extra": {
"example": {
"page": 1,
"limit": 20
}
}
}
model_config = {"json_schema_extra": {"example": {"page": 1, "limit": 20}}}
class SortParams(BaseModel):
"""Parameters for sorting."""
sort_by: Optional[str] = Field(
default=None,
description="Field name to sort by"
)
sort_by: str | None = Field(default=None, description="Field name to sort by")
sort_order: SortOrder = Field(
default=SortOrder.ASC,
description="Sort order (asc or desc)"
default=SortOrder.ASC, description="Sort order (asc or desc)"
)
model_config = {
"json_schema_extra": {
"example": {
"sort_by": "created_at",
"sort_order": "desc"
}
"example": {"sort_by": "created_at", "sort_order": "desc"}
}
}
@@ -92,32 +73,30 @@ class PaginationMeta(BaseModel):
"page_size": 20,
"total_pages": 8,
"has_next": True,
"has_prev": False
"has_prev": False,
}
}
}
class PaginatedResponse(BaseModel, Generic[T]):
class PaginatedResponse[T](BaseModel):
"""Generic paginated response wrapper."""
data: List[T] = Field(..., description="List of items")
data: list[T] = Field(..., description="List of items")
pagination: PaginationMeta = Field(..., description="Pagination metadata")
model_config = {
"json_schema_extra": {
"example": {
"data": [
{"id": "123", "name": "Example Item"}
],
"data": [{"id": "123", "name": "Example Item"}],
"pagination": {
"total": 150,
"page": 1,
"page_size": 20,
"total_pages": 8,
"has_next": True,
"has_prev": False
}
"has_prev": False,
},
}
}
}
@@ -131,10 +110,7 @@ class MessageResponse(BaseModel):
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"message": "Operation completed successfully"
}
"example": {"success": True, "message": "Operation completed successfully"}
}
}
@@ -142,11 +118,11 @@ class MessageResponse(BaseModel):
class BulkActionRequest(BaseModel):
"""Request schema for bulk operations on multiple items."""
ids: List[UUID] = Field(
ids: list[UUID] = Field(
...,
min_length=1,
max_length=100,
description="List of item IDs to perform action on (max 100)"
description="List of item IDs to perform action on (max 100)",
)
model_config = {
@@ -154,7 +130,7 @@ class BulkActionRequest(BaseModel):
"example": {
"ids": [
"550e8400-e29b-41d4-a716-446655440000",
"6ba7b810-9dad-11d1-80b4-00c04fd430c8"
"6ba7b810-9dad-11d1-80b4-00c04fd430c8",
]
}
}
@@ -166,24 +142,23 @@ class BulkActionResponse(BaseModel):
success: bool = Field(default=True, description="Operation success status")
message: str = Field(..., description="Human-readable message")
affected_count: int = Field(..., description="Number of items affected by the operation")
affected_count: int = Field(
..., description="Number of items affected by the operation"
)
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"message": "Successfully deactivated 5 users",
"affected_count": 5
"affected_count": 5,
}
}
}
def create_pagination_meta(
total: int,
page: int,
limit: int,
items_count: int
total: int, page: int, limit: int, items_count: int
) -> PaginationMeta:
"""
Helper function to create pagination metadata.
@@ -205,5 +180,5 @@ def create_pagination_meta(
page_size=items_count,
total_pages=total_pages,
has_next=page < total_pages,
has_prev=page > 1
has_prev=page > 1,
)

View File

@@ -1,8 +1,8 @@
"""
Error schemas for standardized API error responses.
"""
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -53,14 +53,14 @@ class ErrorDetail(BaseModel):
code: ErrorCode = Field(..., description="Machine-readable error code")
message: str = Field(..., description="Human-readable error message")
field: Optional[str] = Field(None, description="Field name if error is field-specific")
field: str | None = Field(None, description="Field name if error is field-specific")
model_config = {
"json_schema_extra": {
"example": {
"code": "VAL_002",
"message": "Password must be at least 8 characters long",
"field": "password"
"field": "password",
}
}
}
@@ -70,7 +70,7 @@ class ErrorResponse(BaseModel):
"""Standardized error response format."""
success: bool = Field(default=False, description="Always false for error responses")
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred")
errors: list[ErrorDetail] = Field(..., description="List of errors that occurred")
model_config = {
"json_schema_extra": {
@@ -80,9 +80,9 @@ class ErrorResponse(BaseModel):
{
"code": "AUTH_001",
"message": "Invalid email or password",
"field": None
"field": None,
}
]
],
}
}
}

View File

@@ -1,10 +1,10 @@
# app/schemas/organizations.py
import re
from datetime import datetime
from typing import Optional, Dict, Any, List
from typing import Any
from uuid import UUID
from pydantic import BaseModel, field_validator, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
from app.models.user_organization import OrganizationRole
@@ -12,85 +12,94 @@ from app.models.user_organization import OrganizationRole
# Organization Schemas
class OrganizationBase(BaseModel):
"""Base organization schema with common fields."""
name: str = Field(..., min_length=1, max_length=255)
slug: Optional[str] = Field(None, min_length=1, max_length=255)
description: Optional[str] = None
is_active: bool = True
settings: Optional[Dict[str, Any]] = {}
@field_validator('slug')
name: str = Field(..., min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
is_active: bool = True
settings: dict[str, Any] | None = {}
@field_validator("slug")
@classmethod
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
if v is None:
return v
if not re.match(r'^[a-z0-9-]+$', v):
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
if v.startswith('-') or v.endswith('-'):
raise ValueError('Slug cannot start or end with a hyphen')
if '--' in v:
raise ValueError('Slug cannot contain consecutive hyphens')
if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError(
"Slug must contain only lowercase letters, numbers, and hyphens"
)
if v.startswith("-") or v.endswith("-"):
raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v
@field_validator('name')
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""Validate organization name."""
if not v or v.strip() == "":
raise ValueError('Organization name cannot be empty')
raise ValueError("Organization name cannot be empty")
return v.strip()
class OrganizationCreate(OrganizationBase):
"""Schema for creating a new organization."""
name: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255)
class OrganizationUpdate(BaseModel):
"""Schema for updating an organization."""
name: Optional[str] = Field(None, min_length=1, max_length=255)
slug: Optional[str] = Field(None, min_length=1, max_length=255)
description: Optional[str] = None
is_active: Optional[bool] = None
settings: Optional[Dict[str, Any]] = None
@field_validator('slug')
name: str | None = Field(None, min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
is_active: bool | None = None
settings: dict[str, Any] | None = None
@field_validator("slug")
@classmethod
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format."""
if v is None:
return v
if not re.match(r'^[a-z0-9-]+$', v):
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
if v.startswith('-') or v.endswith('-'):
raise ValueError('Slug cannot start or end with a hyphen')
if '--' in v:
raise ValueError('Slug cannot contain consecutive hyphens')
if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError(
"Slug must contain only lowercase letters, numbers, and hyphens"
)
if v.startswith("-") or v.endswith("-"):
raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v
@field_validator('name')
@field_validator("name")
@classmethod
def validate_name(cls, v: Optional[str]) -> Optional[str]:
def validate_name(cls, v: str | None) -> str | None:
"""Validate organization name."""
if v is not None and (not v or v.strip() == ""):
raise ValueError('Organization name cannot be empty')
raise ValueError("Organization name cannot be empty")
return v.strip() if v else v
class OrganizationResponse(OrganizationBase):
"""Schema for organization API responses."""
id: UUID
created_at: datetime
updated_at: Optional[datetime] = None
member_count: Optional[int] = 0
updated_at: datetime | None = None
member_count: int | None = 0
model_config = ConfigDict(from_attributes=True)
class OrganizationListResponse(BaseModel):
"""Schema for paginated organization list responses."""
organizations: List[OrganizationResponse]
organizations: list[OrganizationResponse]
total: int
page: int
page_size: int
@@ -100,44 +109,49 @@ class OrganizationListResponse(BaseModel):
# User-Organization Relationship Schemas
class UserOrganizationBase(BaseModel):
"""Base schema for user-organization relationship."""
role: OrganizationRole = OrganizationRole.MEMBER
is_active: bool = True
custom_permissions: Optional[str] = None
custom_permissions: str | None = None
class UserOrganizationCreate(BaseModel):
"""Schema for adding a user to an organization."""
user_id: UUID
role: OrganizationRole = OrganizationRole.MEMBER
custom_permissions: Optional[str] = None
custom_permissions: str | None = None
class UserOrganizationUpdate(BaseModel):
"""Schema for updating user's role in an organization."""
role: Optional[OrganizationRole] = None
is_active: Optional[bool] = None
custom_permissions: Optional[str] = None
role: OrganizationRole | None = None
is_active: bool | None = None
custom_permissions: str | None = None
class UserOrganizationResponse(BaseModel):
"""Schema for user-organization relationship responses."""
user_id: UUID
organization_id: UUID
role: OrganizationRole
is_active: bool
custom_permissions: Optional[str] = None
custom_permissions: str | None = None
created_at: datetime
updated_at: Optional[datetime] = None
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
class OrganizationMemberResponse(BaseModel):
"""Schema for organization member information."""
user_id: UUID
email: str
first_name: str
last_name: Optional[str] = None
last_name: str | None = None
role: OrganizationRole
is_active: bool
joined_at: datetime
@@ -147,7 +161,8 @@ class OrganizationMemberResponse(BaseModel):
class OrganizationMemberListResponse(BaseModel):
"""Schema for paginated organization member list."""
members: List[OrganizationMemberResponse]
members: list[OrganizationMemberResponse]
total: int
page: int
page_size: int

View File

@@ -1,37 +1,44 @@
"""
Pydantic schemas for user session management.
"""
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
class SessionBase(BaseModel):
"""Base schema for user sessions."""
device_name: Optional[str] = Field(None, max_length=255, description="Friendly device name")
device_id: Optional[str] = Field(None, max_length=255, description="Persistent device identifier")
device_name: str | None = Field(
None, max_length=255, description="Friendly device name"
)
device_id: str | None = Field(
None, max_length=255, description="Persistent device identifier"
)
class SessionCreate(SessionBase):
"""Schema for creating a new session (internal use)."""
user_id: UUID
refresh_token_jti: str = Field(..., max_length=255)
ip_address: Optional[str] = Field(None, max_length=45)
user_agent: Optional[str] = Field(None, max_length=500)
ip_address: str | None = Field(None, max_length=45)
user_agent: str | None = Field(None, max_length=500)
last_used_at: datetime
expires_at: datetime
location_city: Optional[str] = Field(None, max_length=100)
location_country: Optional[str] = Field(None, max_length=100)
location_city: str | None = Field(None, max_length=100)
location_country: str | None = Field(None, max_length=100)
class SessionUpdate(BaseModel):
"""Schema for updating a session (internal use)."""
last_used_at: Optional[datetime] = None
is_active: Optional[bool] = None
refresh_token_jti: Optional[str] = None
expires_at: Optional[datetime] = None
last_used_at: datetime | None = None
is_active: bool | None = None
refresh_token_jti: str | None = None
expires_at: datetime | None = None
class SessionResponse(SessionBase):
@@ -40,14 +47,17 @@ class SessionResponse(SessionBase):
This is what users see when they list their active sessions.
"""
id: UUID
ip_address: Optional[str] = None
location_city: Optional[str] = None
location_country: Optional[str] = None
ip_address: str | None = None
location_city: str | None = None
location_country: str | None = None
last_used_at: datetime
created_at: datetime
expires_at: datetime
is_current: bool = Field(default=False, description="Whether this is the current session")
is_current: bool = Field(
default=False, description="Whether this is the current session"
)
model_config = ConfigDict(
from_attributes=True,
@@ -62,14 +72,15 @@ class SessionResponse(SessionBase):
"last_used_at": "2025-10-31T12:00:00Z",
"created_at": "2025-10-30T09:00:00Z",
"expires_at": "2025-11-06T09:00:00Z",
"is_current": True
"is_current": True,
}
}
},
)
class SessionListResponse(BaseModel):
"""Response containing list of sessions."""
sessions: list[SessionResponse]
total: int = Field(..., description="Total number of active sessions")
@@ -84,10 +95,10 @@ class SessionListResponse(BaseModel):
"last_used_at": "2025-10-31T12:00:00Z",
"created_at": "2025-10-30T09:00:00Z",
"expires_at": "2025-11-06T09:00:00Z",
"is_current": True
"is_current": True,
}
],
"total": 1
"total": 1,
}
}
)
@@ -95,17 +106,14 @@ class SessionListResponse(BaseModel):
class LogoutRequest(BaseModel):
"""Request schema for logout endpoint."""
refresh_token: str = Field(
...,
description="Refresh token for the session to logout from",
min_length=10
..., description="Refresh token for the session to logout from", min_length=10
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
"example": {"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}
}
)
@@ -116,13 +124,14 @@ class AdminSessionResponse(SessionBase):
Includes user information for admin to see who owns each session.
"""
id: UUID
user_id: UUID
user_email: str = Field(..., description="Email of the user who owns this session")
user_full_name: Optional[str] = Field(None, description="Full name of the user")
ip_address: Optional[str] = None
location_city: Optional[str] = None
location_country: Optional[str] = None
user_full_name: str | None = Field(None, description="Full name of the user")
ip_address: str | None = None
location_city: str | None = None
location_country: str | None = None
last_used_at: datetime
created_at: datetime
expires_at: datetime
@@ -144,20 +153,21 @@ class AdminSessionResponse(SessionBase):
"last_used_at": "2025-10-31T12:00:00Z",
"created_at": "2025-10-30T09:00:00Z",
"expires_at": "2025-11-06T09:00:00Z",
"is_active": True
"is_active": True,
}
}
},
)
class DeviceInfo(BaseModel):
"""Device information extracted from request."""
device_name: Optional[str] = None
device_id: Optional[str] = None
ip_address: Optional[str] = None
user_agent: Optional[str] = None
location_city: Optional[str] = None
location_country: Optional[str] = None
device_name: str | None = None
device_id: str | None = None
ip_address: str | None = None
user_agent: str | None = None
location_city: str | None = None
location_country: str | None = None
model_config = ConfigDict(
json_schema_extra={
@@ -167,7 +177,7 @@ class DeviceInfo(BaseModel):
"ip_address": "192.168.1.50",
"user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...",
"location_city": "San Francisco",
"location_country": "United States"
"location_country": "United States",
}
}
)

View File

@@ -1,9 +1,9 @@
# app/schemas/users.py
from datetime import datetime
from typing import Optional, Dict, Any
from typing import Any
from uuid import UUID
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from app.schemas.validators import validate_password_strength, validate_phone_number
@@ -11,12 +11,12 @@ from app.schemas.validators import validate_password_strength, validate_phone_nu
class UserBase(BaseModel):
email: EmailStr
first_name: str
last_name: Optional[str] = None
phone_number: Optional[str] = None
last_name: str | None = None
phone_number: str | None = None
@field_validator('phone_number')
@field_validator("phone_number")
@classmethod
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
def validate_phone(cls, v: str | None) -> str | None:
return validate_phone_number(v)
@@ -24,7 +24,7 @@ class UserCreate(UserBase):
password: str
is_superuser: bool = False
@field_validator('password')
@field_validator("password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation"""
@@ -32,30 +32,32 @@ class UserCreate(UserBase):
class UserUpdate(BaseModel):
first_name: Optional[str] = None
last_name: Optional[str] = None
phone_number: Optional[str] = None
password: Optional[str] = None
preferences: Optional[Dict[str, Any]] = None
is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates
is_superuser: Optional[bool] = None # Explicitly reject privilege escalation attempts
first_name: str | None = None
last_name: str | None = None
phone_number: str | None = None
password: str | None = None
preferences: dict[str, Any] | None = None
is_active: bool | None = (
None # Changed default from True to None to avoid unintended updates
)
is_superuser: bool | None = None # Explicitly reject privilege escalation attempts
@field_validator('phone_number')
@field_validator("phone_number")
@classmethod
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
def validate_phone(cls, v: str | None) -> str | None:
return validate_phone_number(v)
@field_validator('password')
@field_validator("password")
@classmethod
def password_strength(cls, v: Optional[str]) -> Optional[str]:
def password_strength(cls, v: str | None) -> str | None:
"""Enterprise-grade password strength validation"""
if v is None:
return v
return validate_password_strength(v)
@field_validator('is_superuser')
@field_validator("is_superuser")
@classmethod
def prevent_superuser_modification(cls, v: Optional[bool]) -> Optional[bool]:
def prevent_superuser_modification(cls, v: bool | None) -> bool | None:
"""Prevent users from modifying their superuser status via this schema."""
if v is not None:
raise ValueError("Cannot modify superuser status through user update")
@@ -67,7 +69,7 @@ class UserInDB(UserBase):
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@@ -77,28 +79,28 @@ class UserResponse(UserBase):
is_active: bool
is_superuser: bool
created_at: datetime
updated_at: Optional[datetime] = None
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
class Token(BaseModel):
access_token: str
refresh_token: Optional[str] = None
refresh_token: str | None = None
token_type: str = "bearer"
user: "UserResponse" # Forward reference since UserResponse is defined above
expires_in: Optional[int] = None # Token expiration in seconds
expires_in: int | None = None # Token expiration in seconds
class TokenPayload(BaseModel):
sub: str # User ID
exp: int # Expiration time
iat: Optional[int] = None # Issued at
jti: Optional[str] = None # JWT ID
is_superuser: Optional[bool] = False
first_name: Optional[str] = None
email: Optional[str] = None
type: Optional[str] = None # Token type (access/refresh)
iat: int | None = None # Issued at
jti: str | None = None # JWT ID
is_superuser: bool | None = False
first_name: str | None = None
email: str | None = None
type: str | None = None # Token type (access/refresh)
class TokenData(BaseModel):
@@ -108,10 +110,11 @@ class TokenData(BaseModel):
class PasswordChange(BaseModel):
"""Schema for changing password (requires current password)."""
current_password: str
new_password: str
@field_validator('new_password')
@field_validator("new_password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation"""
@@ -120,10 +123,11 @@ class PasswordChange(BaseModel):
class PasswordReset(BaseModel):
"""Schema for resetting password (via email token)."""
token: str
new_password: str
@field_validator('new_password')
@field_validator("new_password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation"""
@@ -141,23 +145,19 @@ class RefreshTokenRequest(BaseModel):
class PasswordResetRequest(BaseModel):
"""Schema for requesting a password reset."""
email: EmailStr = Field(..., description="Email address of the account")
model_config = {
"json_schema_extra": {
"example": {
"email": "user@example.com"
}
}
}
model_config = {"json_schema_extra": {"example": {"email": "user@example.com"}}}
class PasswordResetConfirm(BaseModel):
"""Schema for confirming a password reset with token."""
token: str = Field(..., description="Password reset token from email")
new_password: str = Field(..., min_length=8, description="New password")
@field_validator('new_password')
@field_validator("new_password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Enterprise-grade password strength validation"""
@@ -167,7 +167,7 @@ class PasswordResetConfirm(BaseModel):
"json_schema_extra": {
"example": {
"token": "eyJwYXlsb2FkIjp7ImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTcxMjM0NTY3OH19",
"new_password": "NewSecurePassword123"
"new_password": "NewSecurePassword123",
}
}
}

View File

@@ -4,19 +4,34 @@ Shared validators for Pydantic schemas.
This module provides reusable validation functions to ensure consistency
across all schemas and avoid code duplication.
"""
import re
from typing import Set
# Common weak passwords that should be rejected
COMMON_PASSWORDS: Set[str] = {
'password', 'password1', 'password123', 'password1234',
'admin', 'admin123', 'admin1234',
'welcome', 'welcome1', 'welcome123',
'qwerty', 'qwerty123',
'12345678', '123456789', '1234567890',
'letmein', 'letmein1', 'letmein123',
'monkey123', 'dragon123',
'passw0rd', 'p@ssw0rd', 'p@ssword',
COMMON_PASSWORDS: set[str] = {
"password",
"password1",
"password123",
"password1234",
"admin",
"admin123",
"admin1234",
"welcome",
"welcome1",
"welcome123",
"qwerty",
"qwerty123",
"12345678",
"123456789",
"1234567890",
"letmein",
"letmein1",
"letmein123",
"monkey123",
"dragon123",
"passw0rd",
"p@ssw0rd",
"p@ssword",
}
@@ -47,18 +62,21 @@ def validate_password_strength(password: str) -> str:
"""
# Check minimum length
if len(password) < 12:
raise ValueError('Password must be at least 12 characters long')
raise ValueError("Password must be at least 12 characters long")
# Check against common passwords (case-insensitive)
if password.lower() in COMMON_PASSWORDS:
raise ValueError('Password is too common. Please choose a stronger password')
raise ValueError("Password is too common. Please choose a stronger password")
# Check for required character types
checks = [
(any(c.islower() for c in password), 'at least one lowercase letter'),
(any(c.isupper() for c in password), 'at least one uppercase letter'),
(any(c.isdigit() for c in password), 'at least one digit'),
(any(c in '!@#$%^&*()_+-=[]{}|;:,.<>?~`' for c in password), 'at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)')
(any(c.islower() for c in password), "at least one lowercase letter"),
(any(c.isupper() for c in password), "at least one uppercase letter"),
(any(c.isdigit() for c in password), "at least one digit"),
(
any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?~`" for c in password),
"at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)",
),
]
failed = [msg for check, msg in checks if not check]
@@ -94,10 +112,10 @@ def validate_phone_number(phone: str | None) -> str | None:
# Check for empty strings
if not phone or phone.strip() == "":
raise ValueError('Phone number cannot be empty')
raise ValueError("Phone number cannot be empty")
# Remove all spaces and formatting characters
cleaned = re.sub(r'[\s\-\(\)]', '', phone)
cleaned = re.sub(r"[\s\-\(\)]", "", phone)
# Basic pattern:
# Must start with + or 0
@@ -105,19 +123,19 @@ def validate_phone_number(phone: str | None) -> str | None:
# After 0 must have at least 8 digits
# Maximum total length of 15 digits (international standard)
# Only allowed characters are + at start and digits
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
pattern = r"^(?:\+[0-9]{8,14}|0[0-9]{8,14})$"
if not re.match(pattern, cleaned):
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
raise ValueError("Phone number must start with + or 0 followed by 8-14 digits")
# Additional validation to catch specific invalid cases
# NOTE: These checks are defensive code - the regex pattern above already catches these cases
if cleaned.count('+') > 1: # pragma: no cover
raise ValueError('Phone number can only contain one + symbol at the start')
if cleaned.count("+") > 1: # pragma: no cover
raise ValueError("Phone number can only contain one + symbol at the start")
# Check for any non-digit characters (except the leading +)
if not all(c.isdigit() for c in cleaned[1:]): # pragma: no cover
raise ValueError('Phone number can only contain digits after the prefix')
raise ValueError("Phone number can only contain digits after the prefix")
return cleaned
@@ -169,16 +187,16 @@ def validate_slug(slug: str) -> str:
ValueError: If slug format is invalid
"""
if not slug or len(slug) < 2:
raise ValueError('Slug must be at least 2 characters long')
raise ValueError("Slug must be at least 2 characters long")
if len(slug) > 50:
raise ValueError('Slug must be at most 50 characters long')
raise ValueError("Slug must be at most 50 characters long")
# Check format
if not re.match(r'^[a-z0-9]+(?:-[a-z0-9]+)*$', slug):
if not re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", slug):
raise ValueError(
'Slug can only contain lowercase letters, numbers, and hyphens. '
'It cannot start or end with a hyphen, and cannot contain consecutive hyphens'
"Slug can only contain lowercase letters, numbers, and hyphens. "
"It cannot start or end with a hyphen, and cannot contain consecutive hyphens"
)
return slug

View File

@@ -1,18 +1,17 @@
# app/services/auth_service.py
import logging
from typing import Optional
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import (
verify_password_async,
get_password_hash_async,
TokenExpiredError,
TokenInvalidError,
create_access_token,
create_refresh_token,
TokenExpiredError,
TokenInvalidError
get_password_hash_async,
verify_password_async,
)
from app.core.config import settings
from app.core.exceptions import AuthenticationError
@@ -26,7 +25,9 @@ class AuthService:
"""Service for handling authentication operations"""
@staticmethod
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
async def authenticate_user(
db: AsyncSession, email: str, password: str
) -> User | None:
"""
Authenticate a user with email and password using async password verification.
@@ -87,7 +88,7 @@ class AuthService:
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False
is_superuser=False,
)
db.add(user)
@@ -103,8 +104,8 @@ class AuthService:
except Exception as e:
# Rollback on any database errors
await db.rollback()
logger.error(f"Error creating user: {str(e)}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {str(e)}")
logger.error(f"Error creating user: {e!s}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {e!s}")
@staticmethod
def create_tokens(user: User) -> Token:
@@ -121,18 +122,13 @@ class AuthService:
claims = {
"is_superuser": user.is_superuser,
"email": user.email,
"first_name": user.first_name
"first_name": user.first_name,
}
# Create tokens
access_token = create_access_token(
subject=str(user.id),
claims=claims
)
access_token = create_access_token(subject=str(user.id), claims=claims)
refresh_token = create_refresh_token(
subject=str(user.id)
)
refresh_token = create_refresh_token(subject=str(user.id))
# Convert User model to UserResponse schema
user_response = UserResponse.model_validate(user)
@@ -141,7 +137,8 @@ class AuthService:
access_token=access_token,
refresh_token=refresh_token,
user=user_response,
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # Convert minutes to seconds
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES
* 60, # Convert minutes to seconds
)
@staticmethod
@@ -180,11 +177,13 @@ class AuthService:
return AuthService.create_tokens(user)
except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {str(e)}")
logger.warning(f"Token refresh failed: {e!s}")
raise
@staticmethod
async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool:
async def change_password(
db: AsyncSession, user_id: UUID, current_password: str, new_password: str
) -> bool:
"""
Change a user's password.
@@ -223,5 +222,7 @@ class AuthService:
except Exception as e:
# Rollback on any database errors
await db.rollback()
logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True)
raise AuthenticationError(f"Failed to change password: {str(e)}")
logger.error(
f"Error changing password for user {user_id}: {e!s}", exc_info=True
)
raise AuthenticationError(f"Failed to change password: {e!s}")

View File

@@ -5,9 +5,9 @@ Email service with placeholder implementation.
This service provides email sending functionality with a simple console/log-based
placeholder that can be easily replaced with a real email provider (SendGrid, SES, etc.)
"""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
from app.core.config import settings
@@ -20,13 +20,12 @@ class EmailBackend(ABC):
@abstractmethod
async def send_email(
self,
to: List[str],
to: list[str],
subject: str,
html_content: str,
text_content: Optional[str] = None
text_content: str | None = None,
) -> bool:
"""Send an email."""
pass
class ConsoleEmailBackend(EmailBackend):
@@ -39,10 +38,10 @@ class ConsoleEmailBackend(EmailBackend):
async def send_email(
self,
to: List[str],
to: list[str],
subject: str,
html_content: str,
text_content: Optional[str] = None
text_content: str | None = None,
) -> bool:
"""
Log email content to console/logs.
@@ -88,10 +87,10 @@ class SMTPEmailBackend(EmailBackend):
async def send_email(
self,
to: List[str],
to: list[str],
subject: str,
html_content: str,
text_content: Optional[str] = None
text_content: str | None = None,
) -> bool:
"""Send email via SMTP."""
# TODO: Implement SMTP sending
@@ -108,7 +107,7 @@ class EmailService:
and can be configured to use different backends (console, SMTP, SendGrid, etc.)
"""
def __init__(self, backend: Optional[EmailBackend] = None):
def __init__(self, backend: EmailBackend | None = None):
"""
Initialize email service with a backend.
@@ -118,10 +117,7 @@ class EmailService:
self.backend = backend or ConsoleEmailBackend()
async def send_password_reset_email(
self,
to_email: str,
reset_token: str,
user_name: Optional[str] = None
self, to_email: str, reset_token: str, user_name: str | None = None
) -> bool:
"""
Send password reset email.
@@ -142,7 +138,7 @@ class EmailService:
# Plain text version
text_content = f"""
Hello{' ' + user_name if user_name else ''},
Hello{" " + user_name if user_name else ""},
You requested a password reset for your account. Click the link below to reset your password:
@@ -177,7 +173,7 @@ The {settings.PROJECT_NAME} Team
<h1>Password Reset</h1>
</div>
<div class="content">
<p>Hello{' ' + user_name if user_name else ''},</p>
<p>Hello{" " + user_name if user_name else ""},</p>
<p>You requested a password reset for your account. Click the button below to reset your password:</p>
<p style="text-align: center;">
<a href="{reset_url}" class="button">Reset Password</a>
@@ -200,17 +196,14 @@ The {settings.PROJECT_NAME} Team
to=[to_email],
subject=subject,
html_content=html_content,
text_content=text_content
text_content=text_content,
)
except Exception as e:
logger.error(f"Failed to send password reset email to {to_email}: {str(e)}")
logger.error(f"Failed to send password reset email to {to_email}: {e!s}")
return False
async def send_email_verification(
self,
to_email: str,
verification_token: str,
user_name: Optional[str] = None
self, to_email: str, verification_token: str, user_name: str | None = None
) -> bool:
"""
Send email verification email.
@@ -224,14 +217,16 @@ The {settings.PROJECT_NAME} Team
True if email sent successfully
"""
# Generate verification URL
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
verification_url = (
f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
)
# Prepare email content
subject = "Verify Your Email Address"
# Plain text version
text_content = f"""
Hello{' ' + user_name if user_name else ''},
Hello{" " + user_name if user_name else ""},
Thank you for signing up! Please verify your email address by clicking the link below:
@@ -266,7 +261,7 @@ The {settings.PROJECT_NAME} Team
<h1>Verify Your Email</h1>
</div>
<div class="content">
<p>Hello{' ' + user_name if user_name else ''},</p>
<p>Hello{" " + user_name if user_name else ""},</p>
<p>Thank you for signing up! Please verify your email address by clicking the button below:</p>
<p style="text-align: center;">
<a href="{verification_url}" class="button">Verify Email</a>
@@ -289,10 +284,10 @@ The {settings.PROJECT_NAME} Team
to=[to_email],
subject=subject,
html_content=html_content,
text_content=text_content
text_content=text_content,
)
except Exception as e:
logger.error(f"Failed to send verification email to {to_email}: {str(e)}")
logger.error(f"Failed to send verification email to {to_email}: {e!s}")
return False

View File

@@ -3,8 +3,9 @@ Background job for cleaning up expired sessions.
This service runs periodically to remove old session records from the database.
"""
import logging
from datetime import datetime, timezone
from datetime import UTC, datetime
from app.core.database import SessionLocal
from app.crud.session import session as session_crud
@@ -39,7 +40,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
return count
except Exception as e:
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
logger.error(f"Error during session cleanup: {e!s}", exc_info=True)
return 0
@@ -52,20 +53,21 @@ async def get_session_statistics() -> dict:
"""
async with SessionLocal() as db:
try:
from sqlalchemy import func, select
from app.models.user_session import UserSession
from sqlalchemy import select, func
total_result = await db.execute(select(func.count(UserSession.id)))
total_sessions = total_result.scalar_one()
active_result = await db.execute(
select(func.count(UserSession.id)).where(UserSession.is_active == True)
select(func.count(UserSession.id)).where(UserSession.is_active)
)
active_sessions = active_result.scalar_one()
expired_result = await db.execute(
select(func.count(UserSession.id)).where(
UserSession.expires_at < datetime.now(timezone.utc)
UserSession.expires_at < datetime.now(UTC)
)
)
expired_sessions = expired_result.scalar_one()
@@ -82,5 +84,5 @@ async def get_session_statistics() -> dict:
return stats
except Exception as e:
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
logger.error(f"Error getting session statistics: {e!s}", exc_info=True)
return {}

View File

@@ -2,7 +2,8 @@
Authentication utilities for testing.
This module provides tools to bypass FastAPI's authentication in tests.
"""
from typing import Callable, Dict, Optional
from collections.abc import Callable
from fastapi import FastAPI
from fastapi.security import OAuth2PasswordBearer
@@ -13,9 +14,9 @@ from app.models.user import User
def create_test_auth_client(
app: FastAPI,
test_user: User,
extra_overrides: Optional[Dict[Callable, Callable]] = None
app: FastAPI,
test_user: User,
extra_overrides: dict[Callable, Callable] | None = None,
) -> TestClient:
"""
Create a test client with authentication pre-configured.
@@ -47,10 +48,7 @@ def create_test_auth_client(
return TestClient(app)
def create_test_optional_auth_client(
app: FastAPI,
test_user: User
) -> TestClient:
def create_test_optional_auth_client(app: FastAPI, test_user: User) -> TestClient:
"""
Create a test client with optional authentication pre-configured.
@@ -70,10 +68,7 @@ def create_test_optional_auth_client(
return TestClient(app)
def create_test_superuser_client(
app: FastAPI,
test_user: User
) -> TestClient:
def create_test_superuser_client(app: FastAPI, test_user: User) -> TestClient:
"""
Create a test client with superuser authentication pre-configured.
@@ -120,7 +115,7 @@ def cleanup_test_client_auth(app: FastAPI) -> None:
auth_deps = [
get_current_user,
get_optional_current_user,
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login"),
]
# Remove overrides

View File

@@ -1,8 +1,8 @@
"""
Utility functions for extracting and parsing device information from HTTP requests.
"""
import re
from typing import Optional
from fastapi import Request
@@ -19,11 +19,11 @@ def extract_device_info(request: Request) -> DeviceInfo:
Returns:
DeviceInfo object with parsed device information
"""
user_agent = request.headers.get('user-agent', '')
user_agent = request.headers.get("user-agent", "")
device_info = DeviceInfo(
device_name=parse_device_name(user_agent),
device_id=request.headers.get('x-device-id'), # Client must send this header
device_id=request.headers.get("x-device-id"), # Client must send this header
ip_address=get_client_ip(request),
user_agent=user_agent[:500] if user_agent else None, # Truncate to max length
location_city=None, # Can be populated via IP geolocation service
@@ -33,7 +33,7 @@ def extract_device_info(request: Request) -> DeviceInfo:
return device_info
def parse_device_name(user_agent: str) -> Optional[str]:
def parse_device_name(user_agent: str) -> str | None:
"""
Parse user agent string to extract a friendly device name.
@@ -54,48 +54,48 @@ def parse_device_name(user_agent: str) -> Optional[str]:
user_agent_lower = user_agent.lower()
# Mobile devices (check first, as they can contain desktop patterns too)
if 'iphone' in user_agent_lower:
if "iphone" in user_agent_lower:
return "iPhone"
elif 'ipad' in user_agent_lower:
elif "ipad" in user_agent_lower:
return "iPad"
elif 'android' in user_agent_lower:
elif "android" in user_agent_lower:
# Try to extract device model
android_match = re.search(r'android.*;\s*([^)]+)\s*build', user_agent_lower)
android_match = re.search(r"android.*;\s*([^)]+)\s*build", user_agent_lower)
if android_match:
device_model = android_match.group(1).strip()
return f"Android ({device_model.title()})"
return "Android device"
elif 'windows phone' in user_agent_lower:
elif "windows phone" in user_agent_lower:
return "Windows Phone"
# Tablets (check before desktop, as some tablets contain "android")
elif 'tablet' in user_agent_lower:
elif "tablet" in user_agent_lower:
return "Tablet"
# Smart TVs (check before desktop OS patterns)
elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv']):
elif any(tv in user_agent_lower for tv in ["smart-tv", "smarttv"]):
return "Smart TV"
# Game consoles (check before desktop OS patterns, as Xbox contains "Windows")
elif 'playstation' in user_agent_lower:
elif "playstation" in user_agent_lower:
return "PlayStation"
elif 'xbox' in user_agent_lower:
elif "xbox" in user_agent_lower:
return "Xbox"
elif 'nintendo' in user_agent_lower:
elif "nintendo" in user_agent_lower:
return "Nintendo"
# Desktop operating systems
elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower:
elif "macintosh" in user_agent_lower or "mac os x" in user_agent_lower:
# Try to extract browser
browser = extract_browser(user_agent)
return f"{browser} on Mac" if browser else "Mac"
elif 'windows' in user_agent_lower:
elif "windows" in user_agent_lower:
browser = extract_browser(user_agent)
return f"{browser} on Windows" if browser else "Windows PC"
elif 'linux' in user_agent_lower and 'android' not in user_agent_lower:
elif "linux" in user_agent_lower and "android" not in user_agent_lower:
browser = extract_browser(user_agent)
return f"{browser} on Linux" if browser else "Linux"
elif 'cros' in user_agent_lower:
elif "cros" in user_agent_lower:
return "Chromebook"
# Fallback: just return browser name if detected
@@ -106,7 +106,7 @@ def parse_device_name(user_agent: str) -> Optional[str]:
return "Unknown device"
def extract_browser(user_agent: str) -> Optional[str]:
def extract_browser(user_agent: str) -> str | None:
"""
Extract browser name from user agent string.
@@ -126,26 +126,26 @@ def extract_browser(user_agent: str) -> Optional[str]:
user_agent_lower = user_agent.lower()
# Check specific browsers (order matters - check Edge before Chrome!)
if 'edg/' in user_agent_lower or 'edge/' in user_agent_lower:
if "edg/" in user_agent_lower or "edge/" in user_agent_lower:
return "Edge"
elif 'opr/' in user_agent_lower or 'opera' in user_agent_lower:
elif "opr/" in user_agent_lower or "opera" in user_agent_lower:
return "Opera"
elif 'chrome/' in user_agent_lower:
elif "chrome/" in user_agent_lower:
return "Chrome"
elif 'safari/' in user_agent_lower:
elif "safari/" in user_agent_lower:
# Make sure it's actually Safari, not Chrome (which also contains "Safari")
if 'chrome' not in user_agent_lower:
if "chrome" not in user_agent_lower:
return "Safari"
return None
elif 'firefox/' in user_agent_lower:
elif "firefox/" in user_agent_lower:
return "Firefox"
elif 'msie' in user_agent_lower or 'trident/' in user_agent_lower:
elif "msie" in user_agent_lower or "trident/" in user_agent_lower:
return "Internet Explorer"
return None
def get_client_ip(request: Request) -> Optional[str]:
def get_client_ip(request: Request) -> str | None:
"""
Extract client IP address from request, considering proxy headers.
@@ -163,14 +163,14 @@ def get_client_ip(request: Request) -> Optional[str]:
- request.client.host is fallback for direct connections
"""
# Check X-Forwarded-For (common in proxied environments)
x_forwarded_for = request.headers.get('x-forwarded-for')
x_forwarded_for = request.headers.get("x-forwarded-for")
if x_forwarded_for:
# Get the first IP (original client)
client_ip = x_forwarded_for.split(',')[0].strip()
client_ip = x_forwarded_for.split(",")[0].strip()
return client_ip
# Check X-Real-IP (used by some proxies like nginx)
x_real_ip = request.headers.get('x-real-ip')
x_real_ip = request.headers.get("x-real-ip")
if x_real_ip:
return x_real_ip.strip()
@@ -195,9 +195,17 @@ def is_mobile_device(user_agent: str) -> bool:
return False
mobile_patterns = [
'mobile', 'android', 'iphone', 'ipad', 'ipod',
'blackberry', 'windows phone', 'webos', 'opera mini',
'iemobile', 'mobile safari'
"mobile",
"android",
"iphone",
"ipad",
"ipod",
"blackberry",
"windows phone",
"webos",
"opera mini",
"iemobile",
"mobile safari",
]
user_agent_lower = user_agent.lower()
@@ -220,7 +228,7 @@ def get_device_type(user_agent: str) -> str:
user_agent_lower = user_agent.lower()
# Check for tablets first (they can contain "mobile" too)
if 'ipad' in user_agent_lower or 'tablet' in user_agent_lower:
if "ipad" in user_agent_lower or "tablet" in user_agent_lower:
return "tablet"
# Check for mobile
@@ -228,7 +236,7 @@ def get_device_type(user_agent: str) -> str:
return "mobile"
# Check for desktop OS patterns
if any(os in user_agent_lower for os in ['windows', 'macintosh', 'linux', 'cros']):
if any(os in user_agent_lower for os in ["windows", "macintosh", "linux", "cros"]):
return "desktop"
return "other"

View File

@@ -5,18 +5,21 @@ This module provides utilities for creating and verifying signed tokens,
useful for operations like file uploads, password resets, or any other
time-limited, single-use operations.
"""
import base64
import hashlib
import hmac
import json
import secrets
import time
from typing import Dict, Any, Optional
from typing import Any
from app.core.config import settings
def create_upload_token(file_path: str, content_type: str, expires_in: int = 300) -> str:
def create_upload_token(
file_path: str, content_type: str, expires_in: int = 300
) -> str:
"""
Create a signed token for secure file uploads.
@@ -40,34 +43,29 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300
"path": file_path,
"content_type": content_type,
"exp": int(time.time()) + expires_in,
"nonce": secrets.token_hex(8) # Add randomness to prevent token reuse
"nonce": secrets.token_hex(8), # Add randomness to prevent token reuse
}
# Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
# Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
# Combine payload and signature
token_data = {
"payload": payload,
"signature": signature
}
token_data = {"payload": payload, "signature": signature}
# Encode the final token
token_json = json.dumps(token_data)
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
return token
def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
def verify_upload_token(token: str) -> dict[str, Any] | None:
"""
Verify an upload token and return the payload if valid.
@@ -88,7 +86,7 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
"""
try:
# Decode the token
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(token_json)
# Extract payload and signature
@@ -96,11 +94,9 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
signature = token_data["signature"]
# Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):
@@ -136,34 +132,29 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str:
"email": email,
"exp": int(time.time()) + expires_in,
"nonce": secrets.token_hex(16), # Extra randomness
"purpose": "password_reset"
"purpose": "password_reset",
}
# Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
# Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
# Combine payload and signature
token_data = {
"payload": payload,
"signature": signature
}
token_data = {"payload": payload, "signature": signature}
# Encode the final token
token_json = json.dumps(token_data)
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
return token
def verify_password_reset_token(token: str) -> Optional[str]:
def verify_password_reset_token(token: str) -> str | None:
"""
Verify a password reset token and return the email if valid.
@@ -182,7 +173,7 @@ def verify_password_reset_token(token: str) -> Optional[str]:
"""
try:
# Decode the token
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(token_json)
# Extract payload and signature
@@ -194,11 +185,9 @@ def verify_password_reset_token(token: str) -> Optional[str]:
return None
# Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):
@@ -234,34 +223,29 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str:
"email": email,
"exp": int(time.time()) + expires_in,
"nonce": secrets.token_hex(16),
"purpose": "email_verification"
"purpose": "email_verification",
}
# Convert to JSON and encode
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
# Create a signature using HMAC-SHA256 for security
# This prevents length extension attacks that plain SHA-256 is vulnerable to
signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
# Combine payload and signature
token_data = {
"payload": payload,
"signature": signature
}
token_data = {"payload": payload, "signature": signature}
# Encode the final token
token_json = json.dumps(token_data)
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
return token
def verify_email_verification_token(token: str) -> Optional[str]:
def verify_email_verification_token(token: str) -> str | None:
"""
Verify an email verification token and return the email if valid.
@@ -280,7 +264,7 @@ def verify_email_verification_token(token: str) -> Optional[str]:
"""
try:
# Decode the token
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
token_data = json.loads(token_json)
# Extract payload and signature
@@ -292,11 +276,9 @@ def verify_email_verification_token(token: str) -> Optional[str]:
return None
# Verify signature using HMAC and constant-time comparison
payload_bytes = json.dumps(payload).encode('utf-8')
payload_bytes = json.dumps(payload).encode("utf-8")
expected_signature = hmac.new(
settings.SECRET_KEY.encode('utf-8'),
payload_bytes,
hashlib.sha256
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):

View File

@@ -9,17 +9,19 @@ from app.core.database import Base
logger = logging.getLogger(__name__)
def get_test_engine():
"""Create an SQLite in-memory engine specifically for testing"""
test_engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
echo=False,
)
return test_engine
def setup_test_db():
"""Create a test database and session factory"""
# Create a new engine for this test run
@@ -30,14 +32,12 @@ def setup_test_db():
# Create session factory
TestingSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=test_engine,
expire_on_commit=False
autocommit=False, autoflush=False, bind=test_engine, expire_on_commit=False
)
return test_engine, TestingSessionLocal
def teardown_test_db(engine):
"""Clean up after tests"""
# Drop all tables
@@ -46,13 +46,14 @@ def teardown_test_db(engine):
# Dispose of engine
engine.dispose()
async def get_async_test_engine():
"""Create an async SQLite in-memory engine specifically for testing"""
test_engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool, # Use static pool for in-memory testing
echo=False
echo=False,
)
return test_engine
@@ -69,7 +70,7 @@ async def setup_async_test_db():
autoflush=False,
bind=test_engine,
expire_on_commit=False,
class_=AsyncSession
class_=AsyncSession,
)
return test_engine, AsyncTestingSessionLocal