Add pyproject.toml for consolidated project configuration and replace Black, isort, and Flake8 with Ruff
- Introduced `pyproject.toml` to centralize backend tool configurations (e.g., Ruff, mypy, coverage, pytest). - Replaced Black, isort, and Flake8 with Ruff for linting, formatting, and import sorting. - Updated `requirements.txt` to include Ruff and remove replaced tools. - Added `Makefile` to streamline development workflows with commands for linting, formatting, type-checking, testing, and cleanup.
This commit is contained in:
@@ -2,12 +2,11 @@ import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import engine_from_config, pool, text, create_engine
|
||||
from alembic import context
|
||||
from sqlalchemy import create_engine, engine_from_config, pool, text
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from alembic import context
|
||||
|
||||
# Get the path to the app directory (parent of 'alembic')
|
||||
app_dir = Path(__file__).resolve().parent.parent
|
||||
# Add the app directory to Python path
|
||||
@@ -66,7 +65,9 @@ def ensure_database_exists(db_url: str) -> None:
|
||||
admin_url = url.set(database="postgres")
|
||||
|
||||
# CREATE DATABASE cannot run inside a transaction
|
||||
admin_engine = create_engine(str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool)
|
||||
admin_engine = create_engine(
|
||||
str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool
|
||||
)
|
||||
try:
|
||||
with admin_engine.connect() as conn:
|
||||
exists = conn.execute(
|
||||
@@ -122,9 +123,7 @@ def run_migrations_online() -> None:
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
@@ -133,4 +132,4 @@ def run_migrations_online() -> None:
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
run_migrations_online()
|
||||
|
||||
@@ -5,17 +5,17 @@ Revises: fbf6318a8a36
|
||||
Create Date: 2025-11-01 04:15:25.367010
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1174fffbe3e4'
|
||||
down_revision: Union[str, None] = 'fbf6318a8a36'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "1174fffbe3e4"
|
||||
down_revision: str | None = "fbf6318a8a36"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
@@ -24,46 +24,46 @@ def upgrade() -> None:
|
||||
# Index for session cleanup queries
|
||||
# Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff
|
||||
op.create_index(
|
||||
'ix_user_sessions_cleanup',
|
||||
'user_sessions',
|
||||
['is_active', 'expires_at', 'created_at'],
|
||||
"ix_user_sessions_cleanup",
|
||||
"user_sessions",
|
||||
["is_active", "expires_at", "created_at"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('is_active = false')
|
||||
postgresql_where=sa.text("is_active = false"),
|
||||
)
|
||||
|
||||
# Index for user search queries (basic trigram support without pg_trgm extension)
|
||||
# Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%'
|
||||
# Note: For better performance, consider enabling pg_trgm extension
|
||||
op.create_index(
|
||||
'ix_users_email_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(email)')],
|
||||
"ix_users_email_lower",
|
||||
"users",
|
||||
[sa.text("LOWER(email)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_users_first_name_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(first_name)')],
|
||||
"ix_users_first_name_lower",
|
||||
"users",
|
||||
[sa.text("LOWER(first_name)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_users_last_name_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(last_name)')],
|
||||
"ix_users_last_name_lower",
|
||||
"users",
|
||||
[sa.text("LOWER(last_name)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# Index for organization search
|
||||
op.create_index(
|
||||
'ix_organizations_name_lower',
|
||||
'organizations',
|
||||
[sa.text('LOWER(name)')],
|
||||
unique=False
|
||||
"ix_organizations_name_lower",
|
||||
"organizations",
|
||||
[sa.text("LOWER(name)")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -71,8 +71,8 @@ def downgrade() -> None:
|
||||
"""Remove performance indexes."""
|
||||
|
||||
# Drop indexes in reverse order
|
||||
op.drop_index('ix_organizations_name_lower', table_name='organizations')
|
||||
op.drop_index('ix_users_last_name_lower', table_name='users')
|
||||
op.drop_index('ix_users_first_name_lower', table_name='users')
|
||||
op.drop_index('ix_users_email_lower', table_name='users')
|
||||
op.drop_index('ix_user_sessions_cleanup', table_name='user_sessions')
|
||||
op.drop_index("ix_organizations_name_lower", table_name="organizations")
|
||||
op.drop_index("ix_users_last_name_lower", table_name="users")
|
||||
op.drop_index("ix_users_first_name_lower", table_name="users")
|
||||
op.drop_index("ix_users_email_lower", table_name="users")
|
||||
op.drop_index("ix_user_sessions_cleanup", table_name="user_sessions")
|
||||
|
||||
@@ -5,30 +5,32 @@ Revises: 9e4f2a1b8c7d
|
||||
Create Date: 2025-10-30 16:40:21.000021
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2d0fcec3b06d'
|
||||
down_revision: Union[str, None] = '9e4f2a1b8c7d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "2d0fcec3b06d"
|
||||
down_revision: str | None = "9e4f2a1b8c7d"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add deleted_at column for soft deletes
|
||||
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
|
||||
op.add_column(
|
||||
"users", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
|
||||
# Add index on deleted_at for efficient queries
|
||||
op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])
|
||||
op.create_index("ix_users_deleted_at", "users", ["deleted_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove index
|
||||
op.drop_index('ix_users_deleted_at', table_name='users')
|
||||
op.drop_index("ix_users_deleted_at", table_name="users")
|
||||
|
||||
# Remove column
|
||||
op.drop_column('users', 'deleted_at')
|
||||
op.drop_column("users", "deleted_at")
|
||||
|
||||
@@ -5,42 +5,42 @@ Revises: 7396957cbe80
|
||||
Create Date: 2025-02-28 09:19:33.212278
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38bf9e7e74b3'
|
||||
down_revision: Union[str, None] = '7396957cbe80'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "38bf9e7e74b3"
|
||||
down_revision: str | None = "7396957cbe80"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
op.create_table('users',
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('password_hash', sa.String(), nullable=False),
|
||||
sa.Column('first_name', sa.String(), nullable=False),
|
||||
sa.Column('last_name', sa.String(), nullable=True),
|
||||
sa.Column('phone_number', sa.String(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('preferences', sa.JSON(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("email", sa.String(), nullable=False),
|
||||
sa.Column("password_hash", sa.String(), nullable=False),
|
||||
sa.Column("first_name", sa.String(), nullable=False),
|
||||
sa.Column("last_name", sa.String(), nullable=True),
|
||||
sa.Column("phone_number", sa.String(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column("preferences", sa.JSON(), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
||||
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||
op.drop_table("users")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -5,98 +5,85 @@ Revises: b76c725fc3cf
|
||||
Create Date: 2025-10-31 07:41:18.729544
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '549b50ea888d'
|
||||
down_revision: Union[str, None] = 'b76c725fc3cf'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "549b50ea888d"
|
||||
down_revision: str | None = "b76c725fc3cf"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create user_sessions table for per-device session management
|
||||
op.create_table(
|
||||
'user_sessions',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
|
||||
sa.Column('device_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('device_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=500), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('location_city', sa.String(length=100), nullable=True),
|
||||
sa.Column('location_country', sa.String(length=100), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
"user_sessions",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
||||
sa.Column("device_name", sa.String(length=255), nullable=True),
|
||||
sa.Column("device_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("location_city", sa.String(length=100), nullable=True),
|
||||
sa.Column("location_country", sa.String(length=100), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create foreign key to users table
|
||||
op.create_foreign_key(
|
||||
'fk_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
'users',
|
||||
['user_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
"fk_user_sessions_user_id",
|
||||
"user_sessions",
|
||||
"users",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Create indexes for performance
|
||||
# 1. Lookup session by refresh token JTI (most common query)
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti',
|
||||
'user_sessions',
|
||||
['refresh_token_jti'],
|
||||
unique=True
|
||||
"ix_user_sessions_jti", "user_sessions", ["refresh_token_jti"], unique=True
|
||||
)
|
||||
|
||||
# 2. Lookup sessions by user ID
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
['user_id']
|
||||
)
|
||||
op.create_index("ix_user_sessions_user_id", "user_sessions", ["user_id"])
|
||||
|
||||
# 3. Composite index for active sessions by user
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_active',
|
||||
'user_sessions',
|
||||
['user_id', 'is_active']
|
||||
"ix_user_sessions_user_active", "user_sessions", ["user_id", "is_active"]
|
||||
)
|
||||
|
||||
# 4. Index on expires_at for cleanup job
|
||||
op.create_index(
|
||||
'ix_user_sessions_expires_at',
|
||||
'user_sessions',
|
||||
['expires_at']
|
||||
)
|
||||
op.create_index("ix_user_sessions_expires_at", "user_sessions", ["expires_at"])
|
||||
|
||||
# 5. Composite index for active session lookup by JTI
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti_active',
|
||||
'user_sessions',
|
||||
['refresh_token_jti', 'is_active']
|
||||
"ix_user_sessions_jti_active",
|
||||
"user_sessions",
|
||||
["refresh_token_jti", "is_active"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes first
|
||||
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_expires_at', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_id', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_jti', table_name='user_sessions')
|
||||
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_expires_at", table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_user_id", table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_jti", table_name="user_sessions")
|
||||
|
||||
# Drop foreign key
|
||||
op.drop_constraint('fk_user_sessions_user_id', 'user_sessions', type_='foreignkey')
|
||||
op.drop_constraint("fk_user_sessions_user_id", "user_sessions", type_="foreignkey")
|
||||
|
||||
# Drop table
|
||||
op.drop_table('user_sessions')
|
||||
op.drop_table("user_sessions")
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
"""Initial empty migration
|
||||
|
||||
Revision ID: 7396957cbe80
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2025-02-27 12:47:46.445313
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from collections.abc import Sequence
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7396957cbe80'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "7396957cbe80"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -5,80 +5,112 @@ Revises: 38bf9e7e74b3
|
||||
Create Date: 2025-10-30 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9e4f2a1b8c7d'
|
||||
down_revision: Union[str, None] = '38bf9e7e74b3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "9e4f2a1b8c7d"
|
||||
down_revision: str | None = "38bf9e7e74b3"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add missing indexes for is_active and is_superuser
|
||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
||||
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||
)
|
||||
|
||||
# Fix column types to match model definitions with explicit lengths
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"email",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"password_hash",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=False,
|
||||
server_default='user') # Add server default
|
||||
op.alter_column(
|
||||
"users",
|
||||
"first_name",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=False,
|
||||
server_default="user",
|
||||
) # Add server default
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=True)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"last_name",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=20),
|
||||
nullable=True)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"phone_number",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=20),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert column types
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(length=20),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"phone_number",
|
||||
existing_type=sa.String(length=20),
|
||||
type_=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"last_name",
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
server_default=None) # Remove server default
|
||||
op.alter_column(
|
||||
"users",
|
||||
"first_name",
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
server_default=None,
|
||||
) # Remove server default
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"password_hash",
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
op.alter_column(
|
||||
"users",
|
||||
"email",
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Drop indexes
|
||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
||||
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||
|
||||
@@ -5,17 +5,17 @@ Revises: 2d0fcec3b06d
|
||||
Create Date: 2025-10-30 16:41:33.273135
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b76c725fc3cf'
|
||||
down_revision: Union[str, None] = '2d0fcec3b06d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "b76c725fc3cf"
|
||||
down_revision: str | None = "2d0fcec3b06d"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
@@ -23,30 +23,26 @@ def upgrade() -> None:
|
||||
|
||||
# Composite index for filtering active users by role
|
||||
op.create_index(
|
||||
'ix_users_active_superuser',
|
||||
'users',
|
||||
['is_active', 'is_superuser'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
"ix_users_active_superuser",
|
||||
"users",
|
||||
["is_active", "is_superuser"],
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# Composite index for sorting active users by creation date
|
||||
op.create_index(
|
||||
'ix_users_active_created',
|
||||
'users',
|
||||
['is_active', 'created_at'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
"ix_users_active_created",
|
||||
"users",
|
||||
["is_active", "created_at"],
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# Composite index for email lookup of non-deleted users
|
||||
op.create_index(
|
||||
'ix_users_email_not_deleted',
|
||||
'users',
|
||||
['email', 'deleted_at']
|
||||
)
|
||||
op.create_index("ix_users_email_not_deleted", "users", ["email", "deleted_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove composite indexes
|
||||
op.drop_index('ix_users_email_not_deleted', table_name='users')
|
||||
op.drop_index('ix_users_active_created', table_name='users')
|
||||
op.drop_index('ix_users_active_superuser', table_name='users')
|
||||
op.drop_index("ix_users_email_not_deleted", table_name="users")
|
||||
op.drop_index("ix_users_active_created", table_name="users")
|
||||
op.drop_index("ix_users_active_superuser", table_name="users")
|
||||
|
||||
@@ -5,102 +5,123 @@ Revises: 549b50ea888d
|
||||
Create Date: 2025-10-31 12:08:05.141353
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fbf6318a8a36'
|
||||
down_revision: Union[str, None] = '549b50ea888d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "fbf6318a8a36"
|
||||
down_revision: str | None = "549b50ea888d"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create organizations table
|
||||
op.create_table(
|
||||
'organizations',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('slug', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('settings', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
"organizations",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("slug", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("settings", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create indexes for organizations
|
||||
op.create_index('ix_organizations_name', 'organizations', ['name'])
|
||||
op.create_index('ix_organizations_slug', 'organizations', ['slug'], unique=True)
|
||||
op.create_index('ix_organizations_is_active', 'organizations', ['is_active'])
|
||||
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'])
|
||||
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'])
|
||||
op.create_index("ix_organizations_name", "organizations", ["name"])
|
||||
op.create_index("ix_organizations_slug", "organizations", ["slug"], unique=True)
|
||||
op.create_index("ix_organizations_is_active", "organizations", ["is_active"])
|
||||
op.create_index(
|
||||
"ix_organizations_name_active", "organizations", ["name", "is_active"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_organizations_slug_active", "organizations", ["slug", "is_active"]
|
||||
)
|
||||
|
||||
# Create user_organizations junction table
|
||||
op.create_table(
|
||||
'user_organizations',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('organization_id', sa.UUID(), nullable=False),
|
||||
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False, server_default='MEMBER'),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('user_id', 'organization_id')
|
||||
"user_organizations",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("organization_id", sa.UUID(), nullable=False),
|
||||
sa.Column(
|
||||
"role",
|
||||
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
||||
nullable=False,
|
||||
server_default="MEMBER",
|
||||
),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
||||
)
|
||||
|
||||
# Create foreign keys
|
||||
op.create_foreign_key(
|
||||
'fk_user_organizations_user_id',
|
||||
'user_organizations',
|
||||
'users',
|
||||
['user_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
"fk_user_organizations_user_id",
|
||||
"user_organizations",
|
||||
"users",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'fk_user_organizations_organization_id',
|
||||
'user_organizations',
|
||||
'organizations',
|
||||
['organization_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
"fk_user_organizations_organization_id",
|
||||
"user_organizations",
|
||||
"organizations",
|
||||
["organization_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Create indexes for user_organizations
|
||||
op.create_index('ix_user_organizations_role', 'user_organizations', ['role'])
|
||||
op.create_index('ix_user_organizations_is_active', 'user_organizations', ['is_active'])
|
||||
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'])
|
||||
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'])
|
||||
op.create_index("ix_user_organizations_role", "user_organizations", ["role"])
|
||||
op.create_index(
|
||||
"ix_user_organizations_is_active", "user_organizations", ["is_active"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_org_user_active", "user_organizations", ["user_id", "is_active"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_org_org_active", "user_organizations", ["organization_id", "is_active"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes for user_organizations
|
||||
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_organizations_is_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_organizations_role', table_name='user_organizations')
|
||||
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_organizations_is_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_organizations_role", table_name="user_organizations")
|
||||
|
||||
# Drop foreign keys
|
||||
op.drop_constraint('fk_user_organizations_organization_id', 'user_organizations', type_='foreignkey')
|
||||
op.drop_constraint('fk_user_organizations_user_id', 'user_organizations', type_='foreignkey')
|
||||
op.drop_constraint(
|
||||
"fk_user_organizations_organization_id",
|
||||
"user_organizations",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"fk_user_organizations_user_id", "user_organizations", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop user_organizations table
|
||||
op.drop_table('user_organizations')
|
||||
op.drop_table("user_organizations")
|
||||
|
||||
# Drop indexes for organizations
|
||||
op.drop_index('ix_organizations_slug_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_name_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_is_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_slug', table_name='organizations')
|
||||
op.drop_index('ix_organizations_name', table_name='organizations')
|
||||
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
||||
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
||||
op.drop_index("ix_organizations_is_active", table_name="organizations")
|
||||
op.drop_index("ix_organizations_slug", table_name="organizations")
|
||||
op.drop_index("ix_organizations_name", table_name="organizations")
|
||||
|
||||
# Drop organizations table
|
||||
op.drop_table('organizations')
|
||||
op.drop_table("organizations")
|
||||
|
||||
# Drop enum type
|
||||
op.execute('DROP TYPE IF EXISTS organizationrole')
|
||||
op.execute("DROP TYPE IF EXISTS organizationrole")
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status, Header
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
@@ -15,8 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)
|
||||
db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
"""
|
||||
Get the current authenticated user.
|
||||
@@ -36,21 +33,17 @@ async def get_current_user(
|
||||
token_data = get_token_data(token)
|
||||
|
||||
# Get user from database
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == token_data.user_id)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
@@ -59,19 +52,17 @@ async def get_current_user(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token expired",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except TokenInvalidError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Check if the current user is active.
|
||||
|
||||
@@ -86,15 +77,12 @@ def get_current_active_user(
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def get_current_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Check if the current user is a superuser.
|
||||
|
||||
@@ -109,13 +97,12 @@ def get_current_superuser(
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_optional_token(authorization: str = Header(None)) -> Optional[str]:
|
||||
async def get_optional_token(authorization: str = Header(None)) -> str | None:
|
||||
"""
|
||||
Get the token from the Authorization header without requiring it.
|
||||
|
||||
@@ -139,9 +126,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
|
||||
|
||||
|
||||
async def get_optional_current_user(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
token: Optional[str] = Depends(get_optional_token)
|
||||
) -> Optional[User]:
|
||||
db: AsyncSession = Depends(get_db), token: str | None = Depends(get_optional_token)
|
||||
) -> User | None:
|
||||
"""
|
||||
Get the current user if authenticated, otherwise return None.
|
||||
Useful for endpoints that work with both authenticated and unauthenticated users.
|
||||
@@ -158,12 +144,10 @@ async def get_optional_current_user(
|
||||
|
||||
try:
|
||||
token_data = get_token_data(token)
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == token_data.user_id)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
except (TokenExpiredError, TokenInvalidError):
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -7,7 +7,7 @@ These dependencies are optional and flexible:
|
||||
- Use require_org_role for organization-specific access control
|
||||
- Projects can choose to use these or implement their own permission system
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
@@ -20,9 +20,7 @@ from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
|
||||
|
||||
def require_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Dependency to ensure the current user is a superuser.
|
||||
|
||||
@@ -36,7 +34,7 @@ def require_superuser(
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Superuser privileges required"
|
||||
detail="Superuser privileges required",
|
||||
)
|
||||
return current_user
|
||||
|
||||
@@ -62,7 +60,7 @@ class OrganizationPermission:
|
||||
self,
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
Check if user has required role in the organization.
|
||||
@@ -84,21 +82,19 @@ class OrganizationPermission:
|
||||
|
||||
# Get user's role in organization
|
||||
user_role = await organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=organization_id
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
if not user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not a member of this organization"
|
||||
detail="Not a member of this organization",
|
||||
)
|
||||
|
||||
if user_role not in self.allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}"
|
||||
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}",
|
||||
)
|
||||
|
||||
return current_user
|
||||
@@ -106,18 +102,18 @@ class OrganizationPermission:
|
||||
|
||||
# Common permission presets for convenience
|
||||
require_org_owner = OrganizationPermission([OrganizationRole.OWNER])
|
||||
require_org_admin = OrganizationPermission([OrganizationRole.OWNER, OrganizationRole.ADMIN])
|
||||
require_org_member = OrganizationPermission([
|
||||
OrganizationRole.OWNER,
|
||||
OrganizationRole.ADMIN,
|
||||
OrganizationRole.MEMBER
|
||||
])
|
||||
require_org_admin = OrganizationPermission(
|
||||
[OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
)
|
||||
require_org_member = OrganizationPermission(
|
||||
[OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MEMBER]
|
||||
)
|
||||
|
||||
|
||||
async def require_org_membership(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
Ensure user is a member of the organization (any role).
|
||||
@@ -128,15 +124,13 @@ async def require_org_membership(
|
||||
return current_user
|
||||
|
||||
user_role = await organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=organization_id
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
if not user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not a member of this organization"
|
||||
detail="Not a member of this organization",
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth, users, sessions, admin, organizations
|
||||
from app.api.routes import admin, auth, organizations, sessions, users
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
||||
api_router.include_router(organizations.router, prefix="/organizations", tags=["Organizations"])
|
||||
api_router.include_router(
|
||||
organizations.router, prefix="/organizations", tags=["Organizations"]
|
||||
)
|
||||
|
||||
@@ -5,9 +5,10 @@ Admin-specific endpoints for managing users and organizations.
|
||||
These endpoints require superuser privileges and provide CMS-like functionality
|
||||
for managing the application.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
@@ -16,27 +17,32 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode
|
||||
from app.core.exceptions import (
|
||||
AuthorizationError,
|
||||
DuplicateError,
|
||||
ErrorCode,
|
||||
NotFoundError,
|
||||
)
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.crud.session import session as session_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
SortParams,
|
||||
create_pagination_meta
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.organizations import (
|
||||
OrganizationResponse,
|
||||
OrganizationCreate,
|
||||
OrganizationMemberResponse,
|
||||
OrganizationResponse,
|
||||
OrganizationUpdate,
|
||||
OrganizationMemberResponse
|
||||
)
|
||||
from app.schemas.users import UserResponse, UserCreate, UserUpdate
|
||||
from app.schemas.sessions import AdminSessionResponse
|
||||
from app.schemas.users import UserCreate, UserResponse, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,6 +52,7 @@ router = APIRouter()
|
||||
# Schemas for bulk operations
|
||||
class BulkAction(str, Enum):
|
||||
"""Supported bulk actions."""
|
||||
|
||||
ACTIVATE = "activate"
|
||||
DEACTIVATE = "deactivate"
|
||||
DELETE = "delete"
|
||||
@@ -53,36 +60,41 @@ class BulkAction(str, Enum):
|
||||
|
||||
class BulkUserAction(BaseModel):
|
||||
"""Schema for bulk user actions."""
|
||||
|
||||
action: BulkAction = Field(..., description="Action to perform on selected users")
|
||||
user_ids: List[UUID] = Field(..., min_items=1, max_items=100, description="List of user IDs (max 100)")
|
||||
user_ids: list[UUID] = Field(
|
||||
..., min_items=1, max_items=100, description="List of user IDs (max 100)"
|
||||
)
|
||||
|
||||
|
||||
class BulkActionResult(BaseModel):
|
||||
"""Result of a bulk action."""
|
||||
|
||||
success: bool
|
||||
affected_count: int
|
||||
failed_count: int
|
||||
message: str
|
||||
failed_ids: Optional[List[UUID]] = []
|
||||
failed_ids: list[UUID] | None = []
|
||||
|
||||
|
||||
# ===== User Management Endpoints =====
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users",
|
||||
response_model=PaginatedResponse[UserResponse],
|
||||
summary="Admin: List All Users",
|
||||
description="Get paginated list of all users with filtering and search (admin only)",
|
||||
operation_id="admin_list_users"
|
||||
operation_id="admin_list_users",
|
||||
)
|
||||
async def admin_list_users(
|
||||
pagination: PaginationParams = Depends(),
|
||||
sort: SortParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||
search: Optional[str] = Query(None, description="Search by email, name"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
|
||||
search: str | None = Query(None, description="Search by email, name"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with comprehensive filtering and search.
|
||||
@@ -105,20 +117,20 @@ async def admin_list_users(
|
||||
sort_by=sort.sort_by or "created_at",
|
||||
sort_order=sort.sort_order.value if sort.sort_order else "desc",
|
||||
filters=filters if filters else None,
|
||||
search=search
|
||||
search=search,
|
||||
)
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(users)
|
||||
items_count=len(users),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error listing users (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -128,12 +140,12 @@ async def admin_list_users(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Admin: Create User",
|
||||
description="Create a new user (admin only)",
|
||||
operation_id="admin_create_user"
|
||||
operation_id="admin_create_user",
|
||||
)
|
||||
async def admin_create_user(
|
||||
user_in: UserCreate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Create a new user with admin privileges.
|
||||
@@ -145,13 +157,10 @@ async def admin_create_user(
|
||||
logger.info(f"Admin {admin.email} created user {user.email}")
|
||||
return user
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create user: {str(e)}")
|
||||
raise NotFoundError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.USER_ALREADY_EXISTS
|
||||
)
|
||||
logger.warning(f"Failed to create user: {e!s}")
|
||||
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -160,19 +169,18 @@ async def admin_create_user(
|
||||
response_model=UserResponse,
|
||||
summary="Admin: Get User Details",
|
||||
description="Get detailed user information (admin only)",
|
||||
operation_id="admin_get_user"
|
||||
operation_id="admin_get_user",
|
||||
)
|
||||
async def admin_get_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Get detailed information about a specific user."""
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
return user
|
||||
|
||||
@@ -182,21 +190,20 @@ async def admin_get_user(
|
||||
response_model=UserResponse,
|
||||
summary="Admin: Update User",
|
||||
description="Update user information (admin only)",
|
||||
operation_id="admin_update_user"
|
||||
operation_id="admin_update_user",
|
||||
)
|
||||
async def admin_update_user(
|
||||
user_id: UUID,
|
||||
user_in: UserUpdate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Update user information with admin privileges."""
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
|
||||
@@ -206,7 +213,7 @@ async def admin_update_user(
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -215,20 +222,19 @@ async def admin_update_user(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Delete User",
|
||||
description="Soft delete a user (admin only)",
|
||||
operation_id="admin_delete_user"
|
||||
operation_id="admin_delete_user",
|
||||
)
|
||||
async def admin_delete_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Soft delete a user (sets deleted_at timestamp)."""
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
# Prevent deleting yourself
|
||||
@@ -236,21 +242,20 @@ async def admin_delete_user(
|
||||
# Use AuthorizationError for permission/operation restrictions
|
||||
raise AuthorizationError(
|
||||
message="Cannot delete your own account",
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||
)
|
||||
|
||||
await user_crud.soft_delete(db, id=user_id)
|
||||
logger.info(f"Admin {admin.email} deleted user {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} has been deleted"
|
||||
success=True, message=f"User {user.email} has been deleted"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -259,34 +264,32 @@ async def admin_delete_user(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Activate User",
|
||||
description="Activate a user account (admin only)",
|
||||
operation_id="admin_activate_user"
|
||||
operation_id="admin_activate_user",
|
||||
)
|
||||
async def admin_activate_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Activate a user account."""
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
||||
logger.info(f"Admin {admin.email} activated user {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} has been activated"
|
||||
success=True, message=f"User {user.email} has been activated"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error activating user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -295,20 +298,19 @@ async def admin_activate_user(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Deactivate User",
|
||||
description="Deactivate a user account (admin only)",
|
||||
operation_id="admin_deactivate_user"
|
||||
operation_id="admin_deactivate_user",
|
||||
)
|
||||
async def admin_deactivate_user(
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Deactivate a user account."""
|
||||
try:
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
# Prevent deactivating yourself
|
||||
@@ -316,21 +318,20 @@ async def admin_deactivate_user(
|
||||
# Use AuthorizationError for permission/operation restrictions
|
||||
raise AuthorizationError(
|
||||
message="Cannot deactivate your own account",
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||
)
|
||||
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
||||
logger.info(f"Admin {admin.email} deactivated user {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} has been deactivated"
|
||||
success=True, message=f"User {user.email} has been deactivated"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deactivating user (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -339,12 +340,12 @@ async def admin_deactivate_user(
|
||||
response_model=BulkActionResult,
|
||||
summary="Admin: Bulk User Action",
|
||||
description="Perform bulk actions on multiple users (admin only)",
|
||||
operation_id="admin_bulk_user_action"
|
||||
operation_id="admin_bulk_user_action",
|
||||
)
|
||||
async def admin_bulk_user_action(
|
||||
bulk_action: BulkUserAction,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Perform bulk actions on multiple users using optimized bulk operations.
|
||||
@@ -356,22 +357,16 @@ async def admin_bulk_user_action(
|
||||
# Use efficient bulk operations instead of loop
|
||||
if bulk_action.action == BulkAction.ACTIVATE:
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
is_active=True
|
||||
db, user_ids=bulk_action.user_ids, is_active=True
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
is_active=False
|
||||
db, user_ids=bulk_action.user_ids, is_active=False
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DELETE:
|
||||
# bulk_soft_delete automatically excludes the admin user
|
||||
affected_count = await user_crud.bulk_soft_delete(
|
||||
db,
|
||||
user_ids=bulk_action.user_ids,
|
||||
exclude_user_id=admin.id
|
||||
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
|
||||
@@ -390,29 +385,30 @@ async def admin_bulk_user_action(
|
||||
affected_count=affected_count,
|
||||
failed_count=failed_count,
|
||||
message=f"Bulk {bulk_action.action.value}: {affected_count} users affected, {failed_count} skipped",
|
||||
failed_ids=None # Bulk operations don't track individual failures
|
||||
failed_ids=None, # Bulk operations don't track individual failures
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in bulk user action: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error in bulk user action: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# ===== Organization Management Endpoints =====
|
||||
|
||||
|
||||
@router.get(
|
||||
"/organizations",
|
||||
response_model=PaginatedResponse[OrganizationResponse],
|
||||
summary="Admin: List Organizations",
|
||||
description="Get paginated list of all organizations (admin only)",
|
||||
operation_id="admin_list_organizations"
|
||||
operation_id="admin_list_organizations",
|
||||
)
|
||||
async def admin_list_organizations(
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
search: Optional[str] = Query(None, description="Search by name, slug, description"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
search: str | None = Query(None, description="Search by name, slug, description"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""List all organizations with filtering and search."""
|
||||
try:
|
||||
@@ -422,14 +418,14 @@ async def admin_list_organizations(
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active,
|
||||
search=search
|
||||
search=search,
|
||||
)
|
||||
|
||||
# Build response objects from optimized query results
|
||||
orgs_with_count = []
|
||||
for item in orgs_with_data:
|
||||
org = item['organization']
|
||||
member_count = item['member_count']
|
||||
org = item["organization"]
|
||||
member_count = item["member_count"]
|
||||
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
@@ -440,7 +436,7 @@ async def admin_list_organizations(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": member_count
|
||||
"member_count": member_count,
|
||||
}
|
||||
orgs_with_count.append(OrganizationResponse(**org_dict))
|
||||
|
||||
@@ -448,13 +444,13 @@ async def admin_list_organizations(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(orgs_with_count)
|
||||
items_count=len(orgs_with_count),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing organizations (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -464,12 +460,12 @@ async def admin_list_organizations(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Admin: Create Organization",
|
||||
description="Create a new organization (admin only)",
|
||||
operation_id="admin_create_organization"
|
||||
operation_id="admin_create_organization",
|
||||
)
|
||||
async def admin_create_organization(
|
||||
org_in: OrganizationCreate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Create a new organization."""
|
||||
try:
|
||||
@@ -486,18 +482,15 @@ async def admin_create_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": 0
|
||||
"member_count": 0,
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create organization: {str(e)}")
|
||||
raise NotFoundError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.ALREADY_EXISTS
|
||||
)
|
||||
logger.warning(f"Failed to create organization: {e!s}")
|
||||
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -506,19 +499,18 @@ async def admin_create_organization(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Admin: Get Organization Details",
|
||||
description="Get detailed organization information (admin only)",
|
||||
operation_id="admin_get_organization"
|
||||
operation_id="admin_get_organization",
|
||||
)
|
||||
async def admin_get_organization(
|
||||
org_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Get detailed information about a specific organization."""
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
@@ -530,7 +522,9 @@ async def admin_get_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
@@ -540,13 +534,13 @@ async def admin_get_organization(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Admin: Update Organization",
|
||||
description="Update organization information (admin only)",
|
||||
operation_id="admin_update_organization"
|
||||
operation_id="admin_update_organization",
|
||||
)
|
||||
async def admin_update_organization(
|
||||
org_id: UUID,
|
||||
org_in: OrganizationUpdate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Update organization information."""
|
||||
try:
|
||||
@@ -554,7 +548,7 @@ async def admin_update_organization(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||
@@ -569,14 +563,16 @@ async def admin_update_organization(
|
||||
"settings": updated_org.settings,
|
||||
"created_at": updated_org.created_at,
|
||||
"updated_at": updated_org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=updated_org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -585,12 +581,12 @@ async def admin_update_organization(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Delete Organization",
|
||||
description="Delete an organization (admin only)",
|
||||
operation_id="admin_delete_organization"
|
||||
operation_id="admin_delete_organization",
|
||||
)
|
||||
async def admin_delete_organization(
|
||||
org_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Delete an organization and all its relationships."""
|
||||
try:
|
||||
@@ -598,21 +594,20 @@ async def admin_delete_organization(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
await organization_crud.remove(db, id=org_id)
|
||||
logger.info(f"Admin {admin.email} deleted organization {org.name}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Organization {org.name} has been deleted"
|
||||
success=True, message=f"Organization {org.name} has been deleted"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -621,14 +616,14 @@ async def admin_delete_organization(
|
||||
response_model=PaginatedResponse[OrganizationMemberResponse],
|
||||
summary="Admin: List Organization Members",
|
||||
description="Get all members of an organization (admin only)",
|
||||
operation_id="admin_list_organization_members"
|
||||
operation_id="admin_list_organization_members",
|
||||
)
|
||||
async def admin_list_organization_members(
|
||||
org_id: UUID,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: Optional[bool] = Query(True, description="Filter by active status"),
|
||||
is_active: bool | None = Query(True, description="Filter by active status"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""List all members of an organization."""
|
||||
try:
|
||||
@@ -636,7 +631,7 @@ async def admin_list_organization_members(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
members, total = await organization_crud.get_organization_members(
|
||||
@@ -644,7 +639,7 @@ async def admin_list_organization_members(
|
||||
organization_id=org_id,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
# Convert to response models
|
||||
@@ -654,7 +649,7 @@ async def admin_list_organization_members(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(member_responses)
|
||||
items_count=len(member_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
|
||||
@@ -662,14 +657,19 @@ async def admin_list_organization_members(
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing organization members (admin): {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error listing organization members (admin): {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
class AddMemberRequest(BaseModel):
|
||||
"""Request to add a member to an organization."""
|
||||
|
||||
user_id: UUID = Field(..., description="User ID to add")
|
||||
role: OrganizationRole = Field(OrganizationRole.MEMBER, description="Role in organization")
|
||||
role: OrganizationRole = Field(
|
||||
OrganizationRole.MEMBER, description="Role in organization"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -677,13 +677,13 @@ class AddMemberRequest(BaseModel):
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Add Member to Organization",
|
||||
description="Add a user to an organization (admin only)",
|
||||
operation_id="admin_add_organization_member"
|
||||
operation_id="admin_add_organization_member",
|
||||
)
|
||||
async def admin_add_organization_member(
|
||||
org_id: UUID,
|
||||
request: AddMemberRequest,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Add a user to an organization."""
|
||||
try:
|
||||
@@ -691,21 +691,18 @@ async def admin_add_organization_member(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
user = await user_crud.get(db, id=request.user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {request.user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
await organization_crud.add_user(
|
||||
db,
|
||||
organization_id=org_id,
|
||||
user_id=request.user_id,
|
||||
role=request.role
|
||||
db, organization_id=org_id, user_id=request.user_id, role=request.role
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -714,22 +711,21 @@ async def admin_add_organization_member(
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} added to organization {org.name}"
|
||||
success=True, message=f"User {user.email} added to organization {org.name}"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to add user to organization: {str(e)}")
|
||||
logger.warning(f"Failed to add user to organization: {e!s}")
|
||||
# Use DuplicateError for "already exists" scenarios
|
||||
raise DuplicateError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.USER_ALREADY_EXISTS,
|
||||
field="user_id"
|
||||
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
|
||||
)
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding member to organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error adding member to organization (admin): {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -738,13 +734,13 @@ async def admin_add_organization_member(
|
||||
response_model=MessageResponse,
|
||||
summary="Admin: Remove Member from Organization",
|
||||
description="Remove a user from an organization (admin only)",
|
||||
operation_id="admin_remove_organization_member"
|
||||
operation_id="admin_remove_organization_member",
|
||||
)
|
||||
async def admin_remove_organization_member(
|
||||
org_id: UUID,
|
||||
user_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Remove a user from an organization."""
|
||||
try:
|
||||
@@ -752,39 +748,40 @@ async def admin_remove_organization_member(
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
success = await organization_crud.remove_user(
|
||||
db,
|
||||
organization_id=org_id,
|
||||
user_id=user_id
|
||||
db, organization_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise NotFoundError(
|
||||
message="User is not a member of this organization",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
logger.info(f"Admin {admin.email} removed user {user.email} from organization {org.name}")
|
||||
logger.info(
|
||||
f"Admin {admin.email} removed user {user.email} from organization {org.name}"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user.email} removed from organization {org.name}"
|
||||
message=f"User {user.email} removed from organization {org.name}",
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing member from organization (admin): {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error removing member from organization (admin): {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -792,6 +789,7 @@ async def admin_remove_organization_member(
|
||||
# Session Management Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
response_model=PaginatedResponse[AdminSessionResponse],
|
||||
@@ -802,13 +800,13 @@ async def admin_remove_organization_member(
|
||||
Returns paginated list of sessions with user information.
|
||||
Useful for admin dashboard statistics and session monitoring.
|
||||
""",
|
||||
operation_id="admin_list_sessions"
|
||||
operation_id="admin_list_sessions",
|
||||
)
|
||||
async def admin_list_sessions(
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""List all sessions across all users with filtering and pagination."""
|
||||
try:
|
||||
@@ -818,7 +816,7 @@ async def admin_list_sessions(
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
active_only=is_active if is_active is not None else True,
|
||||
with_user=True
|
||||
with_user=True,
|
||||
)
|
||||
|
||||
# Build response objects with user information
|
||||
@@ -847,21 +845,23 @@ async def admin_list_sessions(
|
||||
last_used_at=session.last_used_at,
|
||||
created_at=session.created_at,
|
||||
expires_at=session.expires_at,
|
||||
is_active=session.is_active
|
||||
is_active=session.is_active,
|
||||
)
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})")
|
||||
logger.info(
|
||||
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})"
|
||||
)
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(session_responses)
|
||||
items_count=len(session_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=session_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions (admin): {str(e)}", exc_info=True)
|
||||
logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -1,39 +1,43 @@
|
||||
# app/api/routes/auth.py
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||
from app.core.auth import get_password_hash
|
||||
from app.core.auth import (
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
decode_token,
|
||||
get_password_hash,
|
||||
)
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
AuthenticationError as AuthError,
|
||||
DatabaseError,
|
||||
ErrorCode
|
||||
ErrorCode,
|
||||
)
|
||||
from app.crud.session import session as session_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionCreate, LogoutRequest
|
||||
from app.schemas.sessions import LogoutRequest, SessionCreate
|
||||
from app.schemas.users import (
|
||||
LoginRequest,
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RefreshTokenRequest,
|
||||
Token,
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
Token,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest,
|
||||
PasswordResetRequest,
|
||||
PasswordResetConfirm
|
||||
)
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
from app.services.email_service import email_service
|
||||
from app.utils.device import extract_device_info
|
||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||
@@ -54,7 +58,7 @@ async def _create_login_session(
|
||||
request: Request,
|
||||
user: User,
|
||||
tokens: Token,
|
||||
login_type: str = "login"
|
||||
login_type: str = "login",
|
||||
) -> None:
|
||||
"""
|
||||
Create a session record for successful login.
|
||||
@@ -81,8 +85,8 @@ async def _create_login_session(
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
@@ -95,15 +99,20 @@ async def _create_login_session(
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
||||
@router.post(
|
||||
"/register",
|
||||
response_model=UserResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
operation_id="register",
|
||||
)
|
||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request, user_data: UserCreate, db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new user.
|
||||
@@ -116,25 +125,23 @@ async def register_user(
|
||||
return user
|
||||
except AuthenticationError as e:
|
||||
# SECURITY: Don't reveal if email exists - generic error message
|
||||
logger.warning(f"Registration failed: {str(e)}")
|
||||
logger.warning(f"Registration failed: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Registration failed. Please check your information and try again."
|
||||
detail="Registration failed. Please check your information and try again.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error during registration: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token, operation_id="login")
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
login_data: LoginRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request, login_data: LoginRequest, db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Login with username and password.
|
||||
@@ -146,14 +153,16 @@ async def login(
|
||||
"""
|
||||
try:
|
||||
# Attempt to authenticate the user
|
||||
user = await AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, login_data.email, login_data.password
|
||||
)
|
||||
|
||||
# Explicitly check for None result and raise correct exception
|
||||
if user is None:
|
||||
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||
raise AuthError(
|
||||
message="Invalid email or password",
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||
)
|
||||
|
||||
# User is authenticated, generate tokens
|
||||
@@ -166,29 +175,26 @@ async def login(
|
||||
|
||||
except AuthenticationError as e:
|
||||
# Handle specific authentication errors like inactive accounts
|
||||
logger.warning(f"Authentication failed: {str(e)}")
|
||||
raise AuthError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
logger.warning(f"Authentication failed: {e!s}")
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error during login: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
|
||||
@router.post("/login/oauth", response_model=Token, operation_id="login_oauth")
|
||||
@limiter.limit("10/minute")
|
||||
async def login_oauth(
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
@@ -199,12 +205,14 @@ async def login_oauth(
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, form_data.username, form_data.password
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise AuthError(
|
||||
message="Invalid email or password",
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
@@ -216,28 +224,25 @@ async def login_oauth(
|
||||
# Return full token response with user data
|
||||
return tokens
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"OAuth authentication failed: {str(e)}")
|
||||
raise AuthError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
logger.warning(f"OAuth authentication failed: {e!s}")
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
|
||||
@limiter.limit("30/minute")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
@@ -249,13 +254,17 @@ async def refresh_token(
|
||||
"""
|
||||
try:
|
||||
# Decode the refresh token to get the JTI
|
||||
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
|
||||
refresh_payload = decode_token(
|
||||
refresh_data.refresh_token, verify_type="refresh"
|
||||
)
|
||||
|
||||
# Check if session exists and is active
|
||||
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if not session:
|
||||
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
|
||||
logger.warning(
|
||||
f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session has been revoked. Please log in again.",
|
||||
@@ -274,10 +283,12 @@ async def refresh_token(
|
||||
db,
|
||||
session=session,
|
||||
new_jti=new_refresh_payload.jti,
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc)
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Failed to update session {session.id}: {session_err!s}", exc_info=True
|
||||
)
|
||||
# Continue anyway - tokens are already issued
|
||||
|
||||
return tokens
|
||||
@@ -300,10 +311,10 @@ async def refresh_token(
|
||||
# Re-raise HTTP exceptions (like session revoked)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||
logger.error(f"Unexpected error during token refresh: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
detail="An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
@@ -320,13 +331,13 @@ async def refresh_token(
|
||||
|
||||
**Rate Limit**: 3 requests/minute
|
||||
""",
|
||||
operation_id="request_password_reset"
|
||||
operation_id="request_password_reset",
|
||||
)
|
||||
@limiter.limit("3/minute")
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
reset_request: PasswordResetRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Request a password reset.
|
||||
@@ -345,26 +356,26 @@ async def request_password_reset(
|
||||
|
||||
# Send password reset email
|
||||
await email_service.send_password_reset_email(
|
||||
to_email=user.email,
|
||||
reset_token=reset_token,
|
||||
user_name=user.first_name
|
||||
to_email=user.email, reset_token=reset_token, user_name=user.first_name
|
||||
)
|
||||
logger.info(f"Password reset requested for {user.email}")
|
||||
else:
|
||||
# Log attempt but don't reveal if email exists
|
||||
logger.warning(f"Password reset requested for non-existent or inactive email: {reset_request.email}")
|
||||
logger.warning(
|
||||
f"Password reset requested for non-existent or inactive email: {reset_request.email}"
|
||||
)
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="If your email is registered, you will receive a password reset link shortly"
|
||||
message="If your email is registered, you will receive a password reset link shortly",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing password reset request: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error processing password reset request: {e!s}", exc_info=True)
|
||||
# Still return success to prevent information leakage
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="If your email is registered, you will receive a password reset link shortly"
|
||||
message="If your email is registered, you will receive a password reset link shortly",
|
||||
)
|
||||
|
||||
|
||||
@@ -378,13 +389,13 @@ async def request_password_reset(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="confirm_password_reset"
|
||||
operation_id="confirm_password_reset",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def confirm_password_reset(
|
||||
request: Request,
|
||||
reset_confirm: PasswordResetConfirm,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Confirm password reset with token.
|
||||
@@ -398,7 +409,7 @@ async def confirm_password_reset(
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired password reset token"
|
||||
detail="Invalid or expired password reset token",
|
||||
)
|
||||
|
||||
# Look up user
|
||||
@@ -406,14 +417,13 @@ async def confirm_password_reset(
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User account is inactive"
|
||||
detail="User account is inactive",
|
||||
)
|
||||
|
||||
# Update password
|
||||
@@ -424,29 +434,33 @@ async def confirm_password_reset(
|
||||
# SECURITY: Invalidate all existing sessions after password reset
|
||||
# This prevents stolen sessions from being used after password change
|
||||
from app.crud.session import session as session_crud
|
||||
|
||||
try:
|
||||
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
||||
db,
|
||||
user_id=str(user.id)
|
||||
db, user_id=str(user.id)
|
||||
)
|
||||
logger.info(
|
||||
f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions"
|
||||
)
|
||||
logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions")
|
||||
except Exception as session_error:
|
||||
# Log but don't fail password reset if session invalidation fails
|
||||
logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}")
|
||||
logger.error(
|
||||
f"Failed to invalidate sessions after password reset: {session_error!s}"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password."
|
||||
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password.",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error confirming password reset: {e!s}", exc_info=True)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while resetting your password"
|
||||
detail="An error occurred while resetting your password",
|
||||
)
|
||||
|
||||
|
||||
@@ -464,14 +478,14 @@ async def confirm_password_reset(
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="logout"
|
||||
operation_id="logout",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def logout(
|
||||
request: Request,
|
||||
logout_request: LogoutRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from current device by deactivating the session.
|
||||
@@ -487,15 +501,14 @@ async def logout(
|
||||
try:
|
||||
# Decode refresh token to get JTI
|
||||
try:
|
||||
refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh")
|
||||
refresh_payload = decode_token(
|
||||
logout_request.refresh_token, verify_type="refresh"
|
||||
)
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
# Even if token is expired/invalid, try to deactivate session
|
||||
logger.warning(f"Logout with invalid/expired token: {str(e)}")
|
||||
logger.warning(f"Logout with invalid/expired token: {e!s}")
|
||||
# Don't fail - return success anyway
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
# Find the session by JTI
|
||||
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||
@@ -509,7 +522,7 @@ async def logout(
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only logout your own sessions"
|
||||
detail="You can only logout your own sessions",
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
@@ -522,22 +535,20 @@ async def logout(
|
||||
else:
|
||||
# Session not found - maybe already deleted or never existed
|
||||
# Return success anyway (idempotent)
|
||||
logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})")
|
||||
logger.info(
|
||||
f"Logout requested for non-existent session (JTI: {refresh_payload.jti})"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
# Don't expose error details
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
logger.error(
|
||||
f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
# Don't expose error details
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -553,13 +564,13 @@ async def logout(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="logout_all"
|
||||
operation_id="logout_all",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def logout_all(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from all devices by deactivating all user sessions.
|
||||
@@ -573,19 +584,25 @@ async def logout_all(
|
||||
"""
|
||||
try:
|
||||
# Deactivate all sessions for this user
|
||||
count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
||||
count = await session_crud.deactivate_all_user_sessions(
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
|
||||
logger.info(
|
||||
f"User {current_user.id} logged out from all devices ({count} sessions)"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)"
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while logging out"
|
||||
detail="An error occurred while logging out",
|
||||
)
|
||||
|
||||
@@ -4,8 +4,9 @@ Organization endpoints for regular users.
|
||||
|
||||
These endpoints allow users to view and manage organizations they belong to.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
@@ -14,18 +15,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, ErrorCode
|
||||
from app.core.exceptions import ErrorCode, NotFoundError
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
create_pagination_meta
|
||||
PaginationParams,
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.organizations import (
|
||||
OrganizationResponse,
|
||||
OrganizationMemberResponse,
|
||||
OrganizationUpdate
|
||||
OrganizationResponse,
|
||||
OrganizationUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,15 +36,15 @@ router = APIRouter()
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=List[OrganizationResponse],
|
||||
response_model=list[OrganizationResponse],
|
||||
summary="Get My Organizations",
|
||||
description="Get all organizations the current user belongs to",
|
||||
operation_id="get_my_organizations"
|
||||
operation_id="get_my_organizations",
|
||||
)
|
||||
async def get_my_organizations(
|
||||
is_active: bool = Query(True, description="Filter by active membership"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get all organizations the current user belongs to.
|
||||
@@ -54,15 +55,13 @@ async def get_my_organizations(
|
||||
try:
|
||||
# Get all org data in single query with JOIN and subquery
|
||||
orgs_data = await organization_crud.get_user_organizations_with_details(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
is_active=is_active
|
||||
db, user_id=current_user.id, is_active=is_active
|
||||
)
|
||||
|
||||
# Transform to response objects
|
||||
orgs_with_data = []
|
||||
for item in orgs_data:
|
||||
org = item['organization']
|
||||
org = item["organization"]
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -72,14 +71,14 @@ async def get_my_organizations(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": item['member_count']
|
||||
"member_count": item["member_count"],
|
||||
}
|
||||
orgs_with_data.append(OrganizationResponse(**org_dict))
|
||||
|
||||
return orgs_with_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting user organizations: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -88,12 +87,12 @@ async def get_my_organizations(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Get Organization Details",
|
||||
description="Get details of an organization the user belongs to",
|
||||
operation_id="get_organization"
|
||||
operation_id="get_organization",
|
||||
)
|
||||
async def get_organization(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get details of a specific organization.
|
||||
@@ -105,7 +104,7 @@ async def get_organization(
|
||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
@@ -117,14 +116,16 @@ async def get_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError: # pragma: no cover - See above
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -133,14 +134,14 @@ async def get_organization(
|
||||
response_model=PaginatedResponse[OrganizationMemberResponse],
|
||||
summary="Get Organization Members",
|
||||
description="Get all members of an organization (members can view)",
|
||||
operation_id="get_organization_members"
|
||||
operation_id="get_organization_members",
|
||||
)
|
||||
async def get_organization_members(
|
||||
organization_id: UUID,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get all members of an organization.
|
||||
@@ -153,7 +154,7 @@ async def get_organization_members(
|
||||
organization_id=organization_id,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
member_responses = [OrganizationMemberResponse(**member) for member in members]
|
||||
@@ -162,13 +163,13 @@ async def get_organization_members(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(member_responses)
|
||||
items_count=len(member_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting organization members: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -177,13 +178,13 @@ async def get_organization_members(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Update Organization",
|
||||
description="Update organization details (admin/owner only)",
|
||||
operation_id="update_organization"
|
||||
operation_id="update_organization",
|
||||
)
|
||||
async def update_organization(
|
||||
organization_id: UUID,
|
||||
org_in: OrganizationUpdate,
|
||||
current_user: User = Depends(require_org_admin),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update organization details.
|
||||
@@ -195,11 +196,13 @@ async def update_organization(
|
||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||
logger.info(f"User {current_user.email} updated organization {updated_org.name}")
|
||||
logger.info(
|
||||
f"User {current_user.email} updated organization {updated_org.name}"
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
"id": updated_org.id,
|
||||
@@ -210,12 +213,14 @@ async def update_organization(
|
||||
"settings": updated_org.settings,
|
||||
"created_at": updated_org.created_at,
|
||||
"updated_at": updated_org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=updated_org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError: # pragma: no cover - See above
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating organization: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error updating organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -3,11 +3,12 @@ Session management endpoints.
|
||||
|
||||
Allows users to view and manage their active sessions across devices.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -15,11 +16,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import decode_token
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||
from app.schemas.sessions import SessionListResponse, SessionResponse
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,13 +40,13 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="list_my_sessions"
|
||||
operation_id="list_my_sessions",
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def list_my_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all active sessions for the current user.
|
||||
@@ -60,18 +61,15 @@ async def list_my_sessions(
|
||||
try:
|
||||
# Get all active sessions for user
|
||||
sessions = await session_crud.get_user_sessions(
|
||||
db,
|
||||
user_id=str(current_user.id),
|
||||
active_only=True
|
||||
db, user_id=str(current_user.id), active_only=True
|
||||
)
|
||||
|
||||
# Try to identify current session from Authorization header
|
||||
current_session_jti = None
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
try:
|
||||
access_token = auth_header.split(" ")[1]
|
||||
token_payload = decode_token(access_token)
|
||||
decode_token(access_token)
|
||||
# Note: Access tokens don't have JTI by default, but we can try
|
||||
# For now, we'll mark current based on most recent activity
|
||||
except Exception:
|
||||
@@ -90,22 +88,27 @@ async def list_my_sessions(
|
||||
last_used_at=s.last_used_at,
|
||||
created_at=s.created_at,
|
||||
expires_at=s.expires_at,
|
||||
is_current=(s == sessions[0] if sessions else False) # Most recent = current
|
||||
is_current=(
|
||||
s == sessions[0] if sessions else False
|
||||
), # Most recent = current
|
||||
)
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
|
||||
logger.info(
|
||||
f"User {current_user.id} listed {len(session_responses)} active sessions"
|
||||
)
|
||||
|
||||
return SessionListResponse(
|
||||
sessions=session_responses,
|
||||
total=len(session_responses)
|
||||
sessions=session_responses, total=len(session_responses)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve sessions"
|
||||
detail="Failed to retrieve sessions",
|
||||
)
|
||||
|
||||
|
||||
@@ -122,14 +125,14 @@ async def list_my_sessions(
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="revoke_session"
|
||||
operation_id="revoke_session",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def revoke_session(
|
||||
request: Request,
|
||||
session_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Revoke a specific session by ID.
|
||||
@@ -149,7 +152,7 @@ async def revoke_session(
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
message=f"Session {session_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify session belongs to current user
|
||||
@@ -160,7 +163,7 @@ async def revoke_session(
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="You can only revoke your own sessions",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
@@ -173,16 +176,16 @@ async def revoke_session(
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}"
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}",
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke session"
|
||||
detail="Failed to revoke session",
|
||||
)
|
||||
|
||||
|
||||
@@ -198,13 +201,13 @@ async def revoke_session(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="cleanup_expired_sessions"
|
||||
operation_id="cleanup_expired_sessions",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def cleanup_expired_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Cleanup expired sessions for the current user.
|
||||
@@ -219,21 +222,24 @@ async def cleanup_expired_sessions(
|
||||
try:
|
||||
# Use optimized bulk DELETE instead of N individual deletes
|
||||
deleted_count = await session_crud.cleanup_expired_for_user(
|
||||
db,
|
||||
user_id=str(current_user.id)
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
||||
logger.info(
|
||||
f"User {current_user.id} cleaned up {deleted_count} expired sessions"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Cleaned up {deleted_count} expired sessions"
|
||||
success=True, message=f"Cleaned up {deleted_count} expired sessions"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error cleaning up sessions for user {current_user.id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cleanup sessions"
|
||||
detail="Failed to cleanup sessions",
|
||||
)
|
||||
|
||||
@@ -1,33 +1,30 @@
|
||||
"""
|
||||
User management endpoints for CRUD operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status, Request
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||
from app.api.dependencies.auth import get_current_superuser, get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
NotFoundError,
|
||||
AuthorizationError,
|
||||
ErrorCode
|
||||
)
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
SortParams,
|
||||
create_pagination_meta
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,15 +47,15 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="list_users"
|
||||
operation_id="list_users",
|
||||
)
|
||||
async def list_users(
|
||||
pagination: PaginationParams = Depends(),
|
||||
sort: SortParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with pagination, filtering, and sorting.
|
||||
@@ -80,7 +77,7 @@ async def list_users(
|
||||
limit=pagination.limit,
|
||||
sort_by=sort.sort_by,
|
||||
sort_order=sort.sort_order.value if sort.sort_order else "asc",
|
||||
filters=filters if filters else None
|
||||
filters=filters if filters else None,
|
||||
)
|
||||
|
||||
# Create pagination metadata
|
||||
@@ -88,15 +85,12 @@ async def list_users(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(users)
|
||||
items_count=len(users),
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=users,
|
||||
pagination=pagination_meta
|
||||
)
|
||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error listing users: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -111,11 +105,9 @@ async def list_users(
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_current_user_profile"
|
||||
operation_id="get_current_user_profile",
|
||||
)
|
||||
def get_current_user_profile(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Any:
|
||||
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
|
||||
"""Get current user's profile."""
|
||||
return current_user
|
||||
|
||||
@@ -133,12 +125,12 @@ def get_current_user_profile(
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_current_user"
|
||||
operation_id="update_current_user",
|
||||
)
|
||||
async def update_current_user(
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update current user's profile.
|
||||
@@ -147,17 +139,17 @@ async def update_current_user(
|
||||
"""
|
||||
try:
|
||||
updated_user = await user_crud.update(
|
||||
db,
|
||||
db_obj=current_user,
|
||||
obj_in=user_update
|
||||
db, db_obj=current_user, obj_in=user_update
|
||||
)
|
||||
logger.info(f"User {current_user.id} updated their profile")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {current_user.id}: {str(e)}")
|
||||
logger.error(f"Error updating user {current_user.id}: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -175,12 +167,12 @@ async def update_current_user(
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_user_by_id"
|
||||
operation_id="get_user_by_id",
|
||||
)
|
||||
async def get_user_by_id(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get user by ID.
|
||||
@@ -194,7 +186,7 @@ async def get_user_by_id(
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to view this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
@@ -202,7 +194,7 @@ async def get_user_by_id(
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
return user
|
||||
@@ -222,13 +214,13 @@ async def get_user_by_id(
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_user"
|
||||
operation_id="update_user",
|
||||
)
|
||||
async def update_user(
|
||||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update user by ID.
|
||||
@@ -245,7 +237,7 @@ async def update_user(
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to update this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
@@ -253,7 +245,7 @@ async def update_user(
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -261,10 +253,10 @@ async def update_user(
|
||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {user_id}: {str(e)}")
|
||||
logger.error(f"Error updating user {user_id}: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -281,14 +273,14 @@ async def update_user(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="change_current_user_password"
|
||||
operation_id="change_current_user_password",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def change_current_user_password(
|
||||
request: Request,
|
||||
password_change: PasswordChange,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
@@ -300,23 +292,23 @@ async def change_current_user_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=password_change.current_password,
|
||||
new_password=password_change.new_password
|
||||
new_password=password_change.new_password,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"User {current_user.id} changed their password")
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password changed successfully"
|
||||
success=True, message="Password changed successfully"
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
|
||||
logger.warning(
|
||||
f"Failed password change attempt for user {current_user.id}: {e!s}"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
|
||||
logger.error(f"Error changing password for user {current_user.id}: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -335,12 +327,12 @@ async def change_current_user_password(
|
||||
|
||||
**Note**: This performs a hard delete. Consider implementing soft deletes for production.
|
||||
""",
|
||||
operation_id="delete_user"
|
||||
operation_id="delete_user",
|
||||
)
|
||||
async def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Delete user by ID (superuser only).
|
||||
@@ -351,7 +343,7 @@ async def delete_user(
|
||||
if str(user_id) == str(current_user.id):
|
||||
raise AuthorizationError(
|
||||
message="Cannot delete your own account",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
@@ -359,7 +351,7 @@ async def delete_user(
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -367,12 +359,11 @@ async def delete_user(
|
||||
await user_crud.soft_delete(db, id=str(user_id))
|
||||
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user_id} deleted successfully"
|
||||
success=True, message=f"User {user_id} deleted successfully"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error deleting user {user_id}: {str(e)}")
|
||||
logger.error(f"Error deleting user {user_id}: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
import logging
|
||||
logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import uuid
|
||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
"""Base authentication error"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenExpiredError(AuthError):
|
||||
"""Token has expired"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenInvalidError(AuthError):
|
||||
"""Token is invalid"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenMissingClaimError(AuthError):
|
||||
"""Token is missing a required claim"""
|
||||
pass
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
@@ -62,8 +62,7 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
partial(pwd_context.verify, plain_password, hashed_password)
|
||||
None, partial(pwd_context.verify, plain_password, hashed_password)
|
||||
)
|
||||
|
||||
|
||||
@@ -82,17 +81,13 @@ async def get_password_hash_async(password: str) -> str:
|
||||
Hashed password string
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
pwd_context.hash,
|
||||
password
|
||||
)
|
||||
return await loop.run_in_executor(None, pwd_context.hash, password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
claims: Optional[Dict[str, Any]] = None
|
||||
subject: str | Any,
|
||||
expires_delta: timedelta | None = None,
|
||||
claims: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
@@ -106,17 +101,19 @@ def create_access_token(
|
||||
Encoded JWT token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(UTC) + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
# Base token data
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(tz=timezone.utc),
|
||||
"iat": datetime.now(tz=UTC),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
@@ -125,17 +122,14 @@ def create_access_token(
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
subject: str | Any, expires_delta: timedelta | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT refresh token.
|
||||
@@ -148,28 +142,26 @@ def create_refresh_token(
|
||||
Encoded JWT refresh token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
expire = datetime.now(UTC) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"iat": datetime.now(UTC),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "refresh"
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
"""
|
||||
Decode and verify a JWT token.
|
||||
|
||||
@@ -195,8 +187,8 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "sub", "iat"]
|
||||
}
|
||||
"require": ["exp", "sub", "iat"],
|
||||
},
|
||||
)
|
||||
|
||||
# SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks
|
||||
@@ -250,4 +242,4 @@ def get_token_data(token: str) -> TokenData:
|
||||
user_id = payload.sub
|
||||
is_superuser = payload.is_superuser or False
|
||||
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
@@ -13,7 +12,7 @@ class Settings(BaseSettings):
|
||||
# Environment (must be before SECRET_KEY for validation)
|
||||
ENVIRONMENT: str = Field(
|
||||
default="development",
|
||||
description="Environment: development, staging, or production"
|
||||
description="Environment: development, staging, or production",
|
||||
)
|
||||
|
||||
# Security: Content Security Policy
|
||||
@@ -21,8 +20,7 @@ class Settings(BaseSettings):
|
||||
# Set to True for strict CSP (blocks most external resources)
|
||||
# Set to "relaxed" for modern frontend development
|
||||
CSP_MODE: str = Field(
|
||||
default="relaxed",
|
||||
description="CSP mode: 'strict', 'relaxed', or 'disabled'"
|
||||
default="relaxed", description="CSP mode: 'strict', 'relaxed', or 'disabled'"
|
||||
)
|
||||
|
||||
# Database configuration
|
||||
@@ -31,7 +29,7 @@ class Settings(BaseSettings):
|
||||
POSTGRES_HOST: str = "localhost"
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "app"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
DATABASE_URL: str | None = None
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
@@ -59,38 +57,36 @@ class Settings(BaseSettings):
|
||||
SECRET_KEY: str = Field(
|
||||
default="dev_only_insecure_key_change_in_production_32chars_min",
|
||||
min_length=32,
|
||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
|
||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'",
|
||||
)
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard)
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days
|
||||
|
||||
# CORS configuration
|
||||
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
||||
BACKEND_CORS_ORIGINS: list[str] = ["http://localhost:3000"]
|
||||
|
||||
# Frontend URL for email links
|
||||
FRONTEND_URL: str = Field(
|
||||
default="http://localhost:3000",
|
||||
description="Frontend application URL for email links"
|
||||
description="Frontend application URL for email links",
|
||||
)
|
||||
|
||||
# Admin user
|
||||
FIRST_SUPERUSER_EMAIL: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Email for first superuser account"
|
||||
FIRST_SUPERUSER_EMAIL: str | None = Field(
|
||||
default=None, description="Email for first superuser account"
|
||||
)
|
||||
FIRST_SUPERUSER_PASSWORD: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Password for first superuser (min 12 characters)"
|
||||
FIRST_SUPERUSER_PASSWORD: str | None = Field(
|
||||
default=None, description="Password for first superuser (min 12 characters)"
|
||||
)
|
||||
|
||||
@field_validator('SECRET_KEY')
|
||||
@field_validator("SECRET_KEY")
|
||||
@classmethod
|
||||
def validate_secret_key(cls, v: str, info) -> str:
|
||||
"""Validate SECRET_KEY is secure, especially in production."""
|
||||
# Get environment from values if available
|
||||
values_data = info.data if info.data else {}
|
||||
env = values_data.get('ENVIRONMENT', 'development')
|
||||
env = values_data.get("ENVIRONMENT", "development")
|
||||
|
||||
if v.startswith("your_secret_key_here"):
|
||||
if env == "production":
|
||||
@@ -106,13 +102,15 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
if len(v) < 32:
|
||||
raise ValueError("SECRET_KEY must be at least 32 characters long for security")
|
||||
raise ValueError(
|
||||
"SECRET_KEY must be at least 32 characters long for security"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('FIRST_SUPERUSER_PASSWORD')
|
||||
@field_validator("FIRST_SUPERUSER_PASSWORD")
|
||||
@classmethod
|
||||
def validate_superuser_password(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_superuser_password(cls, v: str | None) -> str | None:
|
||||
"""Validate superuser password strength."""
|
||||
if v is None:
|
||||
return v
|
||||
@@ -121,7 +119,13 @@ class Settings(BaseSettings):
|
||||
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
|
||||
|
||||
# Check for common weak passwords
|
||||
weak_passwords = {'admin123', 'Admin123', 'password123', 'Password123', '123456789012'}
|
||||
weak_passwords = {
|
||||
"admin123",
|
||||
"Admin123",
|
||||
"password123",
|
||||
"Password123",
|
||||
"123456789012",
|
||||
}
|
||||
if v in weak_passwords:
|
||||
raise ValueError(
|
||||
"FIRST_SUPERUSER_PASSWORD is too weak. "
|
||||
@@ -144,8 +148,8 @@ class Settings(BaseSettings):
|
||||
"env_file": "../.env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": True,
|
||||
"extra": "ignore" # Ignore extra fields from .env (e.g., frontend-specific vars)
|
||||
"extra": "ignore", # Ignore extra fields from .env (e.g., frontend-specific vars)
|
||||
}
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings = Settings()
|
||||
|
||||
@@ -5,17 +5,18 @@ Database configuration using SQLAlchemy 2.0 and asyncpg.
|
||||
This module provides async database connectivity with proper connection pooling
|
||||
and session management for FastAPI endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
@@ -27,12 +28,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
@compiles(JSONB, "sqlite")
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
@compiles(UUID, "sqlite")
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
@@ -40,7 +41,6 @@ def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
# Declarative base for models (SQLAlchemy 2.0 style)
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
pass
|
||||
|
||||
|
||||
def get_async_database_url(url: str) -> str:
|
||||
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
logger.debug("Async transaction committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Async transaction failed, rolling back: {str(e)}")
|
||||
logger.error(f"Async transaction failed, rolling back: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Async database health check failed: {str(e)}")
|
||||
logger.error(f"Async database health check failed: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Custom exceptions and global exception handlers for the API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
@@ -27,17 +27,13 @@ class APIException(HTTPException):
|
||||
status_code: int,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
field: Optional[str] = None,
|
||||
headers: Optional[dict] = None
|
||||
field: str | None = None,
|
||||
headers: dict | None = None,
|
||||
):
|
||||
self.error_code = error_code
|
||||
self.field = field
|
||||
self.message = message
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail=message,
|
||||
headers=headers
|
||||
)
|
||||
super().__init__(status_code=status_code, detail=message, headers=headers)
|
||||
|
||||
|
||||
class AuthenticationError(APIException):
|
||||
@@ -47,14 +43,14 @@ class AuthenticationError(APIException):
|
||||
self,
|
||||
message: str = "Authentication failed",
|
||||
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field,
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
@@ -64,12 +60,12 @@ class AuthorizationError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Insufficient permissions",
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,12 +75,12 @@ class NotFoundError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Resource not found",
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -95,13 +91,13 @@ class DuplicateError(APIException):
|
||||
self,
|
||||
message: str = "Resource already exists",
|
||||
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
field=field,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,13 +108,13 @@ class ValidationException(APIException):
|
||||
self,
|
||||
message: str = "Validation error",
|
||||
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
field=field,
|
||||
)
|
||||
|
||||
|
||||
@@ -128,12 +124,12 @@ class DatabaseError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Database operation failed",
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -152,23 +148,18 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=exc.error_code,
|
||||
message=exc.message,
|
||||
field=exc.field
|
||||
)]
|
||||
errors=[ErrorDetail(code=exc.error_code, message=exc.message, field=exc.field)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: Union[RequestValidationError, ValidationError]
|
||||
request: Request, exc: RequestValidationError | ValidationError
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Handler for Pydantic validation errors.
|
||||
@@ -189,22 +180,19 @@ async def validation_exception_handler(
|
||||
# Skip 'body' or 'query' prefix in location
|
||||
field = ".".join(str(x) for x in error["loc"][1:])
|
||||
|
||||
errors.append(ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message=error["msg"],
|
||||
field=field
|
||||
))
|
||||
errors.append(
|
||||
ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR, message=error["msg"], field=field
|
||||
)
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Validation error: {len(errors)} errors "
|
||||
f"(path: {request.url.path})"
|
||||
)
|
||||
logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})")
|
||||
|
||||
error_response = ErrorResponse(errors=errors)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=error_response.model_dump()
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@@ -226,26 +214,21 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
|
||||
}
|
||||
|
||||
error_code = status_code_to_error_code.get(
|
||||
exc.status_code,
|
||||
ErrorCode.INTERNAL_ERROR
|
||||
exc.status_code, ErrorCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"HTTP exception: {exc.status_code} - {exc.detail} "
|
||||
f"(path: {request.url.path})"
|
||||
f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=error_code,
|
||||
message=str(exc.detail)
|
||||
)]
|
||||
errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
|
||||
@@ -257,26 +240,24 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
|
||||
leaking sensitive information in production.
|
||||
"""
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__} - {str(exc)} "
|
||||
f"Unhandled exception: {type(exc).__name__} - {exc!s} "
|
||||
f"(path: {request.url.path})",
|
||||
exc_info=True
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# In production, don't expose internal error details
|
||||
from app.core.config import settings
|
||||
|
||||
if settings.ENVIRONMENT == "production":
|
||||
message = "An internal error occurred. Please try again later."
|
||||
else:
|
||||
message = f"{type(exc).__name__}: {str(exc)}"
|
||||
message = f"{type(exc).__name__}: {exc!s}"
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=message
|
||||
)]
|
||||
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=error_response.model_dump()
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
@@ -3,4 +3,4 @@ from .organization import organization
|
||||
from .session import session as session_crud
|
||||
from .user import user
|
||||
|
||||
__all__ = ["user", "session_crud", "organization"]
|
||||
__all__ = ["organization", "session_crud", "user"]
|
||||
|
||||
@@ -4,14 +4,16 @@ Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
from datetime import UTC
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
@@ -24,10 +26,14 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
class CRUDBase[
|
||||
ModelType: Base,
|
||||
CreateSchemaType: BaseModel,
|
||||
UpdateSchemaType: BaseModel,
|
||||
]:
|
||||
"""Async CRUD operations for a model."""
|
||||
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
def __init__(self, model: type[ModelType]):
|
||||
"""
|
||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
@@ -37,11 +43,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
self.model = model
|
||||
|
||||
async def get(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: str,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> Optional[ModelType]:
|
||||
self, db: AsyncSession, id: str, options: list[Load] | None = None
|
||||
) -> ModelType | None:
|
||||
"""
|
||||
Get a single record by ID with UUID validation and optional eager loading.
|
||||
|
||||
@@ -66,7 +69,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
|
||||
logger.warning(f"Invalid UUID format: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -80,7 +83,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
@@ -89,8 +92,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> List[ModelType]:
|
||||
options: list[Load] | None = None,
|
||||
) -> list[ModelType]:
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
|
||||
@@ -122,10 +125,14 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
logger.error(
|
||||
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: # pragma: no cover
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: CreateSchemaType
|
||||
) -> ModelType: # pragma: no cover
|
||||
"""Create a new record with error handling.
|
||||
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
@@ -142,19 +149,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.warning(
|
||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
|
||||
raise ValueError(f"Database operation failed: {e!s}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
@@ -162,7 +175,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
obj_in: UpdateSchemaType | dict[str, Any],
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
@@ -182,22 +195,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.warning(
|
||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
|
||||
raise ValueError(f"Database operation failed: {e!s}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
@@ -206,7 +225,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -216,7 +235,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
logger.warning(
|
||||
f"{self.model.__name__} with id {id} not found for deletion"
|
||||
)
|
||||
return None
|
||||
|
||||
await db.delete(obj)
|
||||
@@ -224,12 +245,17 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
raise ValueError(
|
||||
f"Cannot delete {self.model.__name__}: referenced by other records"
|
||||
)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_total(
|
||||
@@ -238,10 +264,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[List[ModelType], int]:
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> tuple[list[ModelType], int]:
|
||||
"""
|
||||
Get multiple records with total count, filtering, and sorting.
|
||||
|
||||
@@ -269,7 +295,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
query = select(self.model)
|
||||
|
||||
# Exclude soft-deleted records by default
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
query = query.where(self.model.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
@@ -298,7 +324,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
logger.error(
|
||||
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
@@ -307,7 +335,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
|
||||
logger.error(f"Error counting {self.model.__name__} records: {e!s}")
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
@@ -315,13 +343,13 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj = await self.get(db, id=id)
|
||||
return obj is not None
|
||||
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Soft delete a record by setting deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
@@ -330,7 +358,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -340,26 +368,33 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
||||
logger.warning(
|
||||
f"{self.model.__name__} with id {id} not found for soft deletion"
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if model supports soft deletes
|
||||
if not hasattr(self.model, 'deleted_at'):
|
||||
if not hasattr(self.model, "deleted_at"):
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
raise ValueError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
# Set deleted_at timestamp
|
||||
obj.deleted_at = datetime.now(timezone.utc)
|
||||
obj.deleted_at = datetime.now(UTC)
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||
|
||||
@@ -372,25 +407,28 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Find the soft-deleted record
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
result = await db.execute(
|
||||
select(self.model).where(
|
||||
self.model.id == uuid_obj,
|
||||
self.model.deleted_at.isnot(None)
|
||||
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
|
||||
)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
else:
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
raise ValueError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
|
||||
logger.warning(
|
||||
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
|
||||
)
|
||||
return None
|
||||
|
||||
# Clear deleted_at timestamp
|
||||
@@ -401,5 +439,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
# app/crud/organization_async.py
|
||||
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, and_, select, case
|
||||
from sqlalchemy import and_, case, func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.schemas.organizations import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
|
||||
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||
"""Async CRUD operations for Organization model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
||||
"""Get organization by slug."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
@@ -31,10 +32,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
|
||||
logger.error(f"Error getting organization by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||
) -> Organization:
|
||||
"""Create a new organization with error handling."""
|
||||
try:
|
||||
db_obj = Organization(
|
||||
@@ -42,7 +45,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
is_active=obj_in.is_active,
|
||||
settings=obj_in.settings or {}
|
||||
settings=obj_in.settings or {},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
@@ -50,15 +53,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
|
||||
raise ValueError(
|
||||
f"Organization with slug '{obj_in.slug}' already exists"
|
||||
)
|
||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error creating organization: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
@@ -67,11 +74,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc"
|
||||
) -> tuple[List[Organization], int]:
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Organization], int]:
|
||||
"""
|
||||
Get multiple organizations with filtering, searching, and sorting.
|
||||
|
||||
@@ -89,7 +96,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
@@ -112,7 +119,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
|
||||
return organizations, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with filters: {str(e)}")
|
||||
logger.error(f"Error getting organizations with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
@@ -122,13 +129,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(func.count(UserOrganization.user_id)).where(
|
||||
and_(
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one() or 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error getting member count for organization {organization_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_member_counts(
|
||||
@@ -137,9 +146,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
||||
This eliminates the N+1 query problem.
|
||||
@@ -156,13 +165,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
func.count(
|
||||
func.distinct(
|
||||
case(
|
||||
(UserOrganization.is_active == True, UserOrganization.user_id),
|
||||
else_=None
|
||||
(
|
||||
UserOrganization.is_active,
|
||||
UserOrganization.user_id,
|
||||
),
|
||||
else_=None,
|
||||
)
|
||||
)
|
||||
).label('member_count')
|
||||
).label("member_count"),
|
||||
)
|
||||
.outerjoin(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.group_by(Organization.id)
|
||||
)
|
||||
|
||||
@@ -174,7 +189,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
@@ -189,24 +204,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
query = (
|
||||
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Convert to list of dicts
|
||||
orgs_with_counts = [
|
||||
{
|
||||
'organization': org,
|
||||
'member_count': member_count
|
||||
}
|
||||
{"organization": org, "member_count": member_count}
|
||||
for org, member_count in rows
|
||||
]
|
||||
|
||||
return orgs_with_counts, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error getting organizations with member counts: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def add_user(
|
||||
@@ -216,7 +232,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
try:
|
||||
@@ -225,7 +241,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -249,7 +265,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id=organization_id,
|
||||
role=role,
|
||||
is_active=True,
|
||||
custom_permissions=custom_permissions
|
||||
custom_permissions=custom_permissions,
|
||||
)
|
||||
db.add(user_org)
|
||||
await db.commit()
|
||||
@@ -257,19 +273,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return user_org
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Integrity error adding user to organization: {str(e)}")
|
||||
logger.error(f"Integrity error adding user to organization: {e!s}")
|
||||
raise ValueError("Failed to add user to organization")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID
|
||||
self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
|
||||
) -> bool:
|
||||
"""Remove a user from an organization (soft delete)."""
|
||||
try:
|
||||
@@ -277,7 +289,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -291,7 +303,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return True
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_user_role(
|
||||
@@ -301,15 +313,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole,
|
||||
custom_permissions: Optional[str] = None
|
||||
) -> Optional[UserOrganization]:
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization | None:
|
||||
"""Update a user's role in an organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -326,7 +338,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return user_org
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error updating user role: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_organization_members(
|
||||
@@ -336,8 +348,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool = True
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
is_active: bool = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get members of an organization with user details.
|
||||
|
||||
@@ -359,46 +371,55 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
count_query = select(func.count()).select_from(
|
||||
select(UserOrganization)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
.where(UserOrganization.is_active == is_active if is_active is not None else True)
|
||||
.where(
|
||||
UserOrganization.is_active == is_active
|
||||
if is_active is not None
|
||||
else True
|
||||
)
|
||||
.alias()
|
||||
)
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply ordering and pagination
|
||||
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
|
||||
query = (
|
||||
query.order_by(UserOrganization.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
results = result.all()
|
||||
|
||||
members = []
|
||||
for user_org, user in results:
|
||||
members.append({
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at
|
||||
})
|
||||
members.append(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at,
|
||||
}
|
||||
)
|
||||
|
||||
return members, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {str(e)}")
|
||||
logger.error(f"Error getting organization members: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Organization]:
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
) -> list[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
try:
|
||||
query = (
|
||||
select(Organization)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
@@ -408,16 +429,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {str(e)}")
|
||||
logger.error(f"Error getting user organizations: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get user's organizations with role and member count in SINGLE QUERY.
|
||||
Eliminates N+1 problem by using subquery for member counts.
|
||||
@@ -430,9 +447,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
member_count_subq = (
|
||||
select(
|
||||
UserOrganization.organization_id,
|
||||
func.count(UserOrganization.user_id).label('member_count')
|
||||
func.count(UserOrganization.user_id).label("member_count"),
|
||||
)
|
||||
.where(UserOrganization.is_active == True)
|
||||
.where(UserOrganization.is_active)
|
||||
.group_by(UserOrganization.organization_id)
|
||||
.subquery()
|
||||
)
|
||||
@@ -442,10 +459,18 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(
|
||||
Organization,
|
||||
UserOrganization.role,
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label(
|
||||
"member_count"
|
||||
),
|
||||
)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.outerjoin(
|
||||
member_count_subq,
|
||||
Organization.id == member_count_subq.c.organization_id,
|
||||
)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
@@ -456,25 +481,19 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
'organization': org,
|
||||
'role': role,
|
||||
'member_count': member_count
|
||||
}
|
||||
{"organization": org, "role": role, "member_count": member_count}
|
||||
for org, role, member_count in rows
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error getting user organizations with details: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> Optional[OrganizationRole]:
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> OrganizationRole | None:
|
||||
"""Get a user's role in a specific organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
@@ -482,7 +501,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -490,29 +509,25 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
|
||||
return user_org.role if user_org else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user role in org: {str(e)}")
|
||||
logger.error(f"Error getting user role in org: {e!s}")
|
||||
raise
|
||||
|
||||
async def is_user_org_owner(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner of an organization."""
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role == OrganizationRole.OWNER
|
||||
|
||||
async def is_user_org_admin(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner or admin of an organization."""
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
"""
|
||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select, update, delete, func
|
||||
from sqlalchemy import and_, delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Async CRUD operations for user sessions."""
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""
|
||||
Get session by refresh token JTI.
|
||||
|
||||
@@ -38,10 +38,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
||||
logger.error(f"Error getting session by JTI {jti}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
async def get_active_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> UserSession | None:
|
||||
"""
|
||||
Get active session by refresh token JTI.
|
||||
|
||||
@@ -57,13 +59,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
select(UserSession).where(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active == True
|
||||
UserSession.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
||||
logger.error(f"Error getting active session by JTI {jti}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_sessions(
|
||||
@@ -72,8 +74,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True,
|
||||
with_user: bool = False
|
||||
) -> List[UserSession]:
|
||||
with_user: bool = False,
|
||||
) -> list[UserSession]:
|
||||
"""
|
||||
Get all sessions for a user with optional eager loading.
|
||||
|
||||
@@ -97,20 +99,17 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active == True)
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
query = query.order_by(UserSession.last_used_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
||||
logger.error(f"Error getting sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: SessionCreate
|
||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""
|
||||
Create a new user session.
|
||||
@@ -151,10 +150,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {str(e)}")
|
||||
logger.error(f"Error creating session: {e!s}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {e!s}")
|
||||
|
||||
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
|
||||
async def deactivate(
|
||||
self, db: AsyncSession, *, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""
|
||||
Deactivate a session (logout from device).
|
||||
|
||||
@@ -184,14 +185,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
||||
logger.error(f"Error deactivating session {session_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
self, db: AsyncSession, *, user_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Deactivate all active sessions for a user (logout from all devices).
|
||||
@@ -209,12 +207,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
|
||||
stmt = (
|
||||
update(UserSession)
|
||||
.where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
.where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
|
||||
.values(is_active=False)
|
||||
)
|
||||
|
||||
@@ -228,14 +221,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def update_last_used(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession
|
||||
self, db: AsyncSession, *, session: UserSession
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update the last_used_at timestamp for a session.
|
||||
@@ -248,14 +238,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
||||
logger.error(f"Error updating last_used for session {session.id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def update_refresh_token(
|
||||
@@ -264,7 +254,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime
|
||||
new_expires_at: datetime,
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update session with new refresh token JTI and expiration.
|
||||
@@ -283,14 +273,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error updating refresh token for session {session.id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
@@ -311,15 +303,15 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||
now = datetime.now(timezone.utc)
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False,
|
||||
not UserSession.is_active,
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date
|
||||
UserSession.created_at < cutoff_date,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -334,15 +326,10 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
||||
logger.error(f"Error cleaning up expired sessions: {e!s}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired_for_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""
|
||||
Clean up expired and inactive sessions for a specific user.
|
||||
|
||||
@@ -363,14 +350,14 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
logger.error(f"Invalid UUID format: {user_id}")
|
||||
raise ValueError(f"Invalid user ID format: {user_id}")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.user_id == uuid_obj,
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < now
|
||||
not UserSession.is_active,
|
||||
UserSession.expires_at < now,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -388,7 +375,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
|
||||
f"Error cleaning up expired sessions for user {user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -409,15 +396,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
and_(UserSession.user_id == user_uuid, UserSession.is_active)
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
||||
logger.error(f"Error counting sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_all_sessions(
|
||||
@@ -427,8 +411,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
with_user: bool = True
|
||||
) -> tuple[List[UserSession], int]:
|
||||
with_user: bool = True,
|
||||
) -> tuple[list[UserSession], int]:
|
||||
"""
|
||||
Get all sessions across all users with pagination (admin only).
|
||||
|
||||
@@ -451,18 +435,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active == True)
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count(UserSession.id))
|
||||
if active_only:
|
||||
count_query = count_query.where(UserSession.is_active == True)
|
||||
count_query = count_query.where(UserSession.is_active)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(UserSession.last_used_at.desc()).offset(skip).limit(limit)
|
||||
query = (
|
||||
query.order_by(UserSession.last_used_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
sessions = list(result.scalars().all())
|
||||
@@ -470,7 +458,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
return sessions, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all sessions: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# app/crud/user_async.py
|
||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_, select, update
|
||||
@@ -20,15 +21,13 @@ logger = logging.getLogger(__name__)
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
"""Async CRUD operations for User model."""
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
||||
"""Get user by email address."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == email)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by email {email}: {str(e)}")
|
||||
logger.error(f"Error getting user by email {email}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
@@ -42,9 +41,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
password_hash=password_hash,
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
phone_number=obj_in.phone_number
|
||||
if hasattr(obj_in, "phone_number")
|
||||
else None,
|
||||
is_superuser=obj_in.is_superuser
|
||||
if hasattr(obj_in, "is_superuser")
|
||||
else False,
|
||||
preferences={},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
@@ -52,7 +55,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
||||
@@ -60,15 +63,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
||||
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
|
||||
) -> User:
|
||||
"""Update user with async password hashing if password is updated."""
|
||||
if isinstance(obj_in, dict):
|
||||
@@ -79,7 +78,9 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
# Handle password separately if it exists in update data
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
|
||||
update_data["password_hash"] = await get_password_hash_async(
|
||||
update_data["password"]
|
||||
)
|
||||
del update_data["password"]
|
||||
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
@@ -90,11 +91,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
search: Optional[str] = None
|
||||
) -> Tuple[List[User], int]:
|
||||
filters: dict[str, Any] | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[User], int]:
|
||||
"""
|
||||
Get multiple users with total count, filtering, sorting, and search.
|
||||
|
||||
@@ -136,12 +137,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
search_filter = or_(
|
||||
User.email.ilike(f"%{search}%"),
|
||||
User.first_name.ilike(f"%{search}%"),
|
||||
User.last_name.ilike(f"%{search}%")
|
||||
User.last_name.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
from sqlalchemy import func
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
@@ -162,15 +164,11 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
return users, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
||||
logger.error(f"Error retrieving paginated users: {e!s}")
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
is_active: bool
|
||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||
) -> int:
|
||||
"""
|
||||
Bulk update is_active status for multiple users.
|
||||
@@ -192,7 +190,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
||||
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
|
||||
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
@@ -204,15 +202,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
exclude_user_id: Optional[UUID] = None
|
||||
user_ids: list[UUID],
|
||||
exclude_user_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Bulk soft delete multiple users.
|
||||
@@ -239,11 +237,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
|
||||
.where(
|
||||
User.deleted_at.is_(None)
|
||||
) # Don't re-delete already deleted users
|
||||
.values(
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
deleted_at=datetime.now(UTC),
|
||||
is_active=False,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -256,7 +256,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
|
||||
@@ -4,9 +4,9 @@ Async database initialization script.
|
||||
|
||||
Creates the first superuser if configured and doesn't already exist.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal, engine
|
||||
@@ -17,7 +17,7 @@ from app.schemas.users import UserCreate
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def init_db() -> Optional[User]:
|
||||
async def init_db() -> User | None:
|
||||
"""
|
||||
Initialize database with first superuser if settings are configured and user doesn't exist.
|
||||
|
||||
@@ -49,7 +49,7 @@ async def init_db() -> Optional[User]:
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
@@ -70,13 +70,13 @@ async def main():
|
||||
# Configure logging to show info logs
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
try:
|
||||
user = await init_db()
|
||||
if user:
|
||||
print(f"✓ Database initialized successfully")
|
||||
print("✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
|
||||
@@ -2,10 +2,10 @@ import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from fastapi import FastAPI, status, Request, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
@@ -19,9 +19,9 @@ from app.core.database import check_database_health
|
||||
from app.core.exceptions import (
|
||||
APIException,
|
||||
api_exception_handler,
|
||||
validation_exception_handler,
|
||||
http_exception_handler,
|
||||
unhandled_exception_handler
|
||||
unhandled_exception_handler,
|
||||
validation_exception_handler,
|
||||
)
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
@@ -52,11 +52,11 @@ async def lifespan(app: FastAPI):
|
||||
# Runs daily at 2:00 AM server time
|
||||
scheduler.add_job(
|
||||
cleanup_expired_sessions,
|
||||
'cron',
|
||||
"cron",
|
||||
hour=2,
|
||||
minute=0,
|
||||
id='cleanup_expired_sessions',
|
||||
replace_existing=True
|
||||
id="cleanup_expired_sessions",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
@@ -73,12 +73,12 @@ async def lifespan(app: FastAPI):
|
||||
logger.info("Scheduled jobs stopped")
|
||||
|
||||
|
||||
logger.info(f"Starting app!!!")
|
||||
logger.info("Starting app!!!")
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
version=settings.VERSION,
|
||||
openapi_url=f"{settings.API_V1_STR}/openapi.json",
|
||||
lifespan=lifespan
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add rate limiter state to app
|
||||
@@ -96,7 +96,14 @@ app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only
|
||||
allow_methods=[
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"OPTIONS",
|
||||
], # Explicit methods only
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
@@ -129,12 +136,14 @@ async def limit_request_size(request: Request, call_next):
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
content={
|
||||
"success": False,
|
||||
"errors": [{
|
||||
"code": "REQUEST_TOO_LARGE",
|
||||
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
|
||||
"field": None
|
||||
}]
|
||||
}
|
||||
"errors": [
|
||||
{
|
||||
"code": "REQUEST_TOO_LARGE",
|
||||
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
|
||||
"field": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
@@ -165,15 +174,19 @@ async def add_security_headers(request: Request, call_next):
|
||||
|
||||
# Enforce HTTPS in production
|
||||
if settings.ENVIRONMENT == "production":
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Strict-Transport-Security"] = (
|
||||
"max-age=31536000; includeSubDomains"
|
||||
)
|
||||
|
||||
# Content Security Policy
|
||||
csp_mode = settings.CSP_MODE.lower()
|
||||
|
||||
# Special handling for API docs
|
||||
is_docs = request.url.path in ["/docs", "/redoc"] or \
|
||||
request.url.path.startswith("/docs/") or \
|
||||
request.url.path.startswith("/redoc/")
|
||||
is_docs = (
|
||||
request.url.path in ["/docs", "/redoc"]
|
||||
or request.url.path.startswith("/docs/")
|
||||
or request.url.path.startswith("/redoc/")
|
||||
)
|
||||
|
||||
if csp_mode == "disabled":
|
||||
# No CSP (only for local development/debugging)
|
||||
@@ -264,7 +277,7 @@ async def root():
|
||||
description="Check the health status of the API and its dependencies",
|
||||
response_description="Health status information",
|
||||
tags=["Health"],
|
||||
operation_id="health_check"
|
||||
operation_id="health_check",
|
||||
)
|
||||
async def health_check() -> JSONResponse:
|
||||
"""
|
||||
@@ -278,12 +291,12 @@ async def health_check() -> JSONResponse:
|
||||
- environment: Current environment (development, staging, production)
|
||||
- database: Database connectivity status
|
||||
"""
|
||||
health_status: Dict[str, Any] = {
|
||||
health_status: dict[str, Any] = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"version": settings.VERSION,
|
||||
"environment": settings.ENVIRONMENT,
|
||||
"checks": {}
|
||||
"checks": {},
|
||||
}
|
||||
|
||||
response_status = status.HTTP_200_OK
|
||||
@@ -294,7 +307,7 @@ async def health_check() -> JSONResponse:
|
||||
if db_healthy:
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "Database connection successful"
|
||||
"message": "Database connection successful",
|
||||
}
|
||||
else:
|
||||
raise Exception("Database health check returned unhealthy status")
|
||||
@@ -302,15 +315,12 @@ async def health_check() -> JSONResponse:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Database connection failed: {str(e)}"
|
||||
"message": f"Database connection failed: {e!s}",
|
||||
}
|
||||
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
logger.error(f"Health check failed - database error: {e}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response_status,
|
||||
content=health_status
|
||||
)
|
||||
return JSONResponse(status_code=response_status, content=health_status)
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
@@ -2,17 +2,25 @@
|
||||
Models package initialization.
|
||||
Imports all models to ensure they're registered with SQLAlchemy.
|
||||
"""
|
||||
|
||||
# First import Base to avoid circular imports
|
||||
from app.core.database import Base
|
||||
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
from .organization import Organization
|
||||
|
||||
# Import models
|
||||
from .user import User
|
||||
from .user_organization import UserOrganization, OrganizationRole
|
||||
from .user_organization import OrganizationRole, UserOrganization
|
||||
from .user_session import UserSession
|
||||
|
||||
__all__ = [
|
||||
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||
'User', 'UserSession',
|
||||
'Organization', 'UserOrganization', 'OrganizationRole',
|
||||
]
|
||||
"Base",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
"User",
|
||||
"UserOrganization",
|
||||
"UserSession",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin to add created_at and updated_at timestamps to models"""
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin to add UUID primary keys to models"""
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# app/models/organization.py
|
||||
from sqlalchemy import Column, String, Boolean, Text, Index
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -11,7 +11,8 @@ class Organization(Base, UUIDMixin, TimestampMixin):
|
||||
Organization model for multi-tenant support.
|
||||
Users can belong to multiple organizations with different roles.
|
||||
"""
|
||||
__tablename__ = 'organizations'
|
||||
|
||||
__tablename__ = "organizations"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
@@ -20,11 +21,13 @@ class Organization(Base, UUIDMixin, TimestampMixin):
|
||||
settings = Column(JSONB, default={})
|
||||
|
||||
# Relationships
|
||||
user_organizations = relationship("UserOrganization", back_populates="organization", cascade="all, delete-orphan")
|
||||
user_organizations = relationship(
|
||||
"UserOrganization", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_organizations_name_active', 'name', 'is_active'),
|
||||
Index('ix_organizations_slug_active', 'slug', 'is_active'),
|
||||
Index("ix_organizations_name_active", "name", "is_active"),
|
||||
Index("ix_organizations_slug_active", "slug", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Boolean, DateTime
|
||||
from sqlalchemy import Boolean, Column, DateTime, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -6,7 +6,7 @@ from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class User(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = 'users'
|
||||
__tablename__ = "users"
|
||||
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
@@ -19,7 +19,9 @@ class User(Base, UUIDMixin, TimestampMixin):
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
user_organizations = relationship("UserOrganization", back_populates="user", cascade="all, delete-orphan")
|
||||
user_organizations = relationship(
|
||||
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
return f"<User {self.email}>"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# app/models/user_organization.py
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Boolean, String, Index, Enum
|
||||
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -14,6 +14,7 @@ class OrganizationRole(str, PyEnum):
|
||||
These provide a baseline role system that can be optionally used.
|
||||
Projects can extend this or implement their own permission system.
|
||||
"""
|
||||
|
||||
OWNER = "owner" # Full control over organization
|
||||
ADMIN = "admin" # Can manage users and settings
|
||||
MEMBER = "member" # Regular member with standard access
|
||||
@@ -25,25 +26,41 @@ class UserOrganization(Base, TimestampMixin):
|
||||
Junction table for many-to-many relationship between Users and Organizations.
|
||||
Includes role information for flexible RBAC.
|
||||
"""
|
||||
__tablename__ = 'user_organizations'
|
||||
|
||||
user_id = Column(PGUUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), primary_key=True)
|
||||
organization_id = Column(PGUUID(as_uuid=True), ForeignKey('organizations.id', ondelete='CASCADE'), primary_key=True)
|
||||
__tablename__ = "user_organizations"
|
||||
|
||||
role = Column(Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False, index=True)
|
||||
user_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
organization_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("organizations.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
role = Column(
|
||||
Enum(OrganizationRole),
|
||||
default=OrganizationRole.MEMBER,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Optional: Custom permissions override for specific users
|
||||
custom_permissions = Column(String(500), nullable=True) # JSON array of permission strings
|
||||
custom_permissions = Column(
|
||||
String(500), nullable=True
|
||||
) # JSON array of permission strings
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="user_organizations")
|
||||
organization = relationship("Organization", back_populates="user_organizations")
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_user_org_user_active', 'user_id', 'is_active'),
|
||||
Index('ix_user_org_org_active', 'organization_id', 'is_active'),
|
||||
Index('ix_user_org_role', 'role'),
|
||||
Index("ix_user_org_user_active", "user_id", "is_active"),
|
||||
Index("ix_user_org_org_active", "organization_id", "is_active"),
|
||||
Index("ix_user_org_role", "role"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -6,7 +6,10 @@ This allows users to:
|
||||
- Logout from specific devices
|
||||
- Manage their active sessions
|
||||
"""
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Index
|
||||
|
||||
from datetime import UTC
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -20,19 +23,27 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
Each time a user logs in from a device, a new session is created.
|
||||
Sessions are identified by the refresh token JTI (JWT ID).
|
||||
"""
|
||||
__tablename__ = 'user_sessions'
|
||||
|
||||
__tablename__ = "user_sessions"
|
||||
|
||||
# Foreign key to user
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Refresh token identifier (JWT ID from the refresh token)
|
||||
refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# Device information
|
||||
device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook"
|
||||
device_id = Column(String(255), nullable=True) # Persistent device identifier (from client)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
|
||||
user_agent = Column(String(500), nullable=True) # Browser/app user agent
|
||||
device_id = Column(
|
||||
String(255), nullable=True
|
||||
) # Persistent device identifier (from client)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
|
||||
user_agent = Column(String(500), nullable=True) # Browser/app user agent
|
||||
|
||||
# Session timing
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=False)
|
||||
@@ -50,8 +61,8 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Composite indexes for performance (defined in migration)
|
||||
__table_args__ = (
|
||||
Index('ix_user_sessions_user_active', 'user_id', 'is_active'),
|
||||
Index('ix_user_sessions_jti_active', 'refresh_token_jti', 'is_active'),
|
||||
Index("ix_user_sessions_user_active", "user_id", "is_active"),
|
||||
Index("ix_user_sessions_jti_active", "refresh_token_jti", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -60,21 +71,24 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if session has expired."""
|
||||
from datetime import datetime, timezone
|
||||
return self.expires_at < datetime.now(timezone.utc)
|
||||
from datetime import datetime
|
||||
|
||||
return self.expires_at < datetime.now(UTC)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert session to dictionary for serialization."""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'user_id': str(self.user_id),
|
||||
'device_name': self.device_name,
|
||||
'device_id': self.device_id,
|
||||
'ip_address': self.ip_address,
|
||||
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
'expires_at': self.expires_at.isoformat() if self.expires_at else None,
|
||||
'is_active': self.is_active,
|
||||
'location_city': self.location_city,
|
||||
'location_country': self.location_country,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
"id": str(self.id),
|
||||
"user_id": str(self.user_id),
|
||||
"device_name": self.device_name,
|
||||
"device_id": self.device_id,
|
||||
"ip_address": self.ip_address,
|
||||
"last_used_at": self.last_used_at.isoformat()
|
||||
if self.last_used_at
|
||||
else None,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"is_active": self.is_active,
|
||||
"location_city": self.location_city,
|
||||
"location_country": self.location_country,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
"""
|
||||
Common schemas used across the API for pagination, responses, filtering, and sorting.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
from typing import Generic, TypeVar, List, Optional
|
||||
from typing import TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order options."""
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
@@ -20,16 +22,9 @@ class SortOrder(str, Enum):
|
||||
class PaginationParams(BaseModel):
|
||||
"""Parameters for pagination."""
|
||||
|
||||
page: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
description="Page number (1-indexed)"
|
||||
)
|
||||
page: int = Field(default=1, ge=1, description="Page number (1-indexed)")
|
||||
limit: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of items per page (max 100)"
|
||||
default=20, ge=1, le=100, description="Number of items per page (max 100)"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -42,34 +37,20 @@ class PaginationParams(BaseModel):
|
||||
"""Alias for offset (compatibility with existing code)."""
|
||||
return self.offset
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"page": 1,
|
||||
"limit": 20
|
||||
}
|
||||
}
|
||||
}
|
||||
model_config = {"json_schema_extra": {"example": {"page": 1, "limit": 20}}}
|
||||
|
||||
|
||||
class SortParams(BaseModel):
|
||||
"""Parameters for sorting."""
|
||||
|
||||
sort_by: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Field name to sort by"
|
||||
)
|
||||
sort_by: str | None = Field(default=None, description="Field name to sort by")
|
||||
sort_order: SortOrder = Field(
|
||||
default=SortOrder.ASC,
|
||||
description="Sort order (asc or desc)"
|
||||
default=SortOrder.ASC, description="Sort order (asc or desc)"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"sort_by": "created_at",
|
||||
"sort_order": "desc"
|
||||
}
|
||||
"example": {"sort_by": "created_at", "sort_order": "desc"}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,32 +73,30 @@ class PaginationMeta(BaseModel):
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
"has_prev": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
class PaginatedResponse[T](BaseModel):
|
||||
"""Generic paginated response wrapper."""
|
||||
|
||||
data: List[T] = Field(..., description="List of items")
|
||||
data: list[T] = Field(..., description="List of items")
|
||||
pagination: PaginationMeta = Field(..., description="Pagination metadata")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"data": [
|
||||
{"id": "123", "name": "Example Item"}
|
||||
],
|
||||
"data": [{"id": "123", "name": "Example Item"}],
|
||||
"pagination": {
|
||||
"total": 150,
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
}
|
||||
"has_prev": False,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -131,10 +110,7 @@ class MessageResponse(BaseModel):
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Operation completed successfully"
|
||||
}
|
||||
"example": {"success": True, "message": "Operation completed successfully"}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,11 +118,11 @@ class MessageResponse(BaseModel):
|
||||
class BulkActionRequest(BaseModel):
|
||||
"""Request schema for bulk operations on multiple items."""
|
||||
|
||||
ids: List[UUID] = Field(
|
||||
ids: list[UUID] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="List of item IDs to perform action on (max 100)"
|
||||
description="List of item IDs to perform action on (max 100)",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
@@ -154,7 +130,7 @@ class BulkActionRequest(BaseModel):
|
||||
"example": {
|
||||
"ids": [
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"6ba7b810-9dad-11d1-80b4-00c04fd430c8"
|
||||
"6ba7b810-9dad-11d1-80b4-00c04fd430c8",
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -166,24 +142,23 @@ class BulkActionResponse(BaseModel):
|
||||
|
||||
success: bool = Field(default=True, description="Operation success status")
|
||||
message: str = Field(..., description="Human-readable message")
|
||||
affected_count: int = Field(..., description="Number of items affected by the operation")
|
||||
affected_count: int = Field(
|
||||
..., description="Number of items affected by the operation"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Successfully deactivated 5 users",
|
||||
"affected_count": 5
|
||||
"affected_count": 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_pagination_meta(
|
||||
total: int,
|
||||
page: int,
|
||||
limit: int,
|
||||
items_count: int
|
||||
total: int, page: int, limit: int, items_count: int
|
||||
) -> PaginationMeta:
|
||||
"""
|
||||
Helper function to create pagination metadata.
|
||||
@@ -205,5 +180,5 @@ def create_pagination_meta(
|
||||
page_size=items_count,
|
||||
total_pages=total_pages,
|
||||
has_next=page < total_pages,
|
||||
has_prev=page > 1
|
||||
has_prev=page > 1,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Error schemas for standardized API error responses.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -53,14 +53,14 @@ class ErrorDetail(BaseModel):
|
||||
|
||||
code: ErrorCode = Field(..., description="Machine-readable error code")
|
||||
message: str = Field(..., description="Human-readable error message")
|
||||
field: Optional[str] = Field(None, description="Field name if error is field-specific")
|
||||
field: str | None = Field(None, description="Field name if error is field-specific")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"code": "VAL_002",
|
||||
"message": "Password must be at least 8 characters long",
|
||||
"field": "password"
|
||||
"field": "password",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,7 @@ class ErrorResponse(BaseModel):
|
||||
"""Standardized error response format."""
|
||||
|
||||
success: bool = Field(default=False, description="Always false for error responses")
|
||||
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
errors: list[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
@@ -80,9 +80,9 @@ class ErrorResponse(BaseModel):
|
||||
{
|
||||
"code": "AUTH_001",
|
||||
"message": "Invalid email or password",
|
||||
"field": None
|
||||
"field": None,
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# app/schemas/organizations.py
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, field_validator, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from app.models.user_organization import OrganizationRole
|
||||
|
||||
@@ -12,85 +12,94 @@ from app.models.user_organization import OrganizationRole
|
||||
# Organization Schemas
|
||||
class OrganizationBase(BaseModel):
|
||||
"""Base organization schema with common fields."""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = None
|
||||
is_active: bool = True
|
||||
settings: Optional[Dict[str, Any]] = {}
|
||||
|
||||
@field_validator('slug')
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
is_active: bool = True
|
||||
settings: dict[str, Any] | None = {}
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r'^[a-z0-9-]+$', v):
|
||||
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
|
||||
if v.startswith('-') or v.endswith('-'):
|
||||
raise ValueError('Slug cannot start or end with a hyphen')
|
||||
if '--' in v:
|
||||
raise ValueError('Slug cannot contain consecutive hyphens')
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate organization name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError('Organization name cannot be empty')
|
||||
raise ValueError("Organization name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class OrganizationCreate(OrganizationBase):
|
||||
"""Schema for creating a new organization."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
|
||||
|
||||
class OrganizationUpdate(BaseModel):
|
||||
"""Schema for updating an organization."""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
slug: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
settings: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator('slug')
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
is_active: bool | None = None
|
||||
settings: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r'^[a-z0-9-]+$', v):
|
||||
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
|
||||
if v.startswith('-') or v.endswith('-'):
|
||||
raise ValueError('Slug cannot start or end with a hyphen')
|
||||
if '--' in v:
|
||||
raise ValueError('Slug cannot contain consecutive hyphens')
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate organization name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError('Organization name cannot be empty')
|
||||
raise ValueError("Organization name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class OrganizationResponse(OrganizationBase):
|
||||
"""Schema for organization API responses."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
member_count: Optional[int] = 0
|
||||
updated_at: datetime | None = None
|
||||
member_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationListResponse(BaseModel):
|
||||
"""Schema for paginated organization list responses."""
|
||||
organizations: List[OrganizationResponse]
|
||||
|
||||
organizations: list[OrganizationResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
@@ -100,44 +109,49 @@ class OrganizationListResponse(BaseModel):
|
||||
# User-Organization Relationship Schemas
|
||||
class UserOrganizationBase(BaseModel):
|
||||
"""Base schema for user-organization relationship."""
|
||||
|
||||
role: OrganizationRole = OrganizationRole.MEMBER
|
||||
is_active: bool = True
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationCreate(BaseModel):
|
||||
"""Schema for adding a user to an organization."""
|
||||
|
||||
user_id: UUID
|
||||
role: OrganizationRole = OrganizationRole.MEMBER
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationUpdate(BaseModel):
|
||||
"""Schema for updating user's role in an organization."""
|
||||
role: Optional[OrganizationRole] = None
|
||||
is_active: Optional[bool] = None
|
||||
custom_permissions: Optional[str] = None
|
||||
|
||||
role: OrganizationRole | None = None
|
||||
is_active: bool | None = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationResponse(BaseModel):
|
||||
"""Schema for user-organization relationship responses."""
|
||||
|
||||
user_id: UUID
|
||||
organization_id: UUID
|
||||
role: OrganizationRole
|
||||
is_active: bool
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationMemberResponse(BaseModel):
|
||||
"""Schema for organization member information."""
|
||||
|
||||
user_id: UUID
|
||||
email: str
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
last_name: str | None = None
|
||||
role: OrganizationRole
|
||||
is_active: bool
|
||||
joined_at: datetime
|
||||
@@ -147,7 +161,8 @@ class OrganizationMemberResponse(BaseModel):
|
||||
|
||||
class OrganizationMemberListResponse(BaseModel):
|
||||
"""Schema for paginated organization member list."""
|
||||
members: List[OrganizationMemberResponse]
|
||||
|
||||
members: list[OrganizationMemberResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
@@ -1,37 +1,44 @@
|
||||
"""
|
||||
Pydantic schemas for user session management.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class SessionBase(BaseModel):
|
||||
"""Base schema for user sessions."""
|
||||
device_name: Optional[str] = Field(None, max_length=255, description="Friendly device name")
|
||||
device_id: Optional[str] = Field(None, max_length=255, description="Persistent device identifier")
|
||||
|
||||
device_name: str | None = Field(
|
||||
None, max_length=255, description="Friendly device name"
|
||||
)
|
||||
device_id: str | None = Field(
|
||||
None, max_length=255, description="Persistent device identifier"
|
||||
)
|
||||
|
||||
|
||||
class SessionCreate(SessionBase):
|
||||
"""Schema for creating a new session (internal use)."""
|
||||
|
||||
user_id: UUID
|
||||
refresh_token_jti: str = Field(..., max_length=255)
|
||||
ip_address: Optional[str] = Field(None, max_length=45)
|
||||
user_agent: Optional[str] = Field(None, max_length=500)
|
||||
ip_address: str | None = Field(None, max_length=45)
|
||||
user_agent: str | None = Field(None, max_length=500)
|
||||
last_used_at: datetime
|
||||
expires_at: datetime
|
||||
location_city: Optional[str] = Field(None, max_length=100)
|
||||
location_country: Optional[str] = Field(None, max_length=100)
|
||||
location_city: str | None = Field(None, max_length=100)
|
||||
location_country: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class SessionUpdate(BaseModel):
|
||||
"""Schema for updating a session (internal use)."""
|
||||
last_used_at: Optional[datetime] = None
|
||||
is_active: Optional[bool] = None
|
||||
refresh_token_jti: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
last_used_at: datetime | None = None
|
||||
is_active: bool | None = None
|
||||
refresh_token_jti: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class SessionResponse(SessionBase):
|
||||
@@ -40,14 +47,17 @@ class SessionResponse(SessionBase):
|
||||
|
||||
This is what users see when they list their active sessions.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
ip_address: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
ip_address: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
last_used_at: datetime
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
is_current: bool = Field(default=False, description="Whether this is the current session")
|
||||
is_current: bool = Field(
|
||||
default=False, description="Whether this is the current session"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
@@ -62,14 +72,15 @@ class SessionResponse(SessionBase):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
"is_current": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
"""Response containing list of sessions."""
|
||||
|
||||
sessions: list[SessionResponse]
|
||||
total: int = Field(..., description="Total number of active sessions")
|
||||
|
||||
@@ -84,10 +95,10 @@ class SessionListResponse(BaseModel):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
"is_current": True,
|
||||
}
|
||||
],
|
||||
"total": 1
|
||||
"total": 1,
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -95,17 +106,14 @@ class SessionListResponse(BaseModel):
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request schema for logout endpoint."""
|
||||
|
||||
refresh_token: str = Field(
|
||||
...,
|
||||
description="Refresh token for the session to logout from",
|
||||
min_length=10
|
||||
..., description="Refresh token for the session to logout from", min_length=10
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
||||
}
|
||||
"example": {"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -116,13 +124,14 @@ class AdminSessionResponse(SessionBase):
|
||||
|
||||
Includes user information for admin to see who owns each session.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
user_id: UUID
|
||||
user_email: str = Field(..., description="Email of the user who owns this session")
|
||||
user_full_name: Optional[str] = Field(None, description="Full name of the user")
|
||||
ip_address: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
user_full_name: str | None = Field(None, description="Full name of the user")
|
||||
ip_address: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
last_used_at: datetime
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
@@ -144,20 +153,21 @@ class AdminSessionResponse(SessionBase):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_active": True
|
||||
"is_active": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class DeviceInfo(BaseModel):
|
||||
"""Device information extracted from request."""
|
||||
device_name: Optional[str] = None
|
||||
device_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
|
||||
device_name: str | None = None
|
||||
device_id: str | None = None
|
||||
ip_address: str | None = None
|
||||
user_agent: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
@@ -167,7 +177,7 @@ class DeviceInfo(BaseModel):
|
||||
"ip_address": "192.168.1.50",
|
||||
"user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...",
|
||||
"location_city": "San Francisco",
|
||||
"location_country": "United States"
|
||||
"location_country": "United States",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# app/schemas/users.py
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
|
||||
from app.schemas.validators import validate_password_strength, validate_phone_number
|
||||
|
||||
@@ -11,12 +11,12 @@ from app.schemas.validators import validate_password_strength, validate_phone_nu
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
last_name: str | None = None
|
||||
phone_number: str | None = None
|
||||
|
||||
@field_validator('phone_number')
|
||||
@field_validator("phone_number")
|
||||
@classmethod
|
||||
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_phone(cls, v: str | None) -> str | None:
|
||||
return validate_phone_number(v)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class UserCreate(UserBase):
|
||||
password: str
|
||||
is_superuser: bool = False
|
||||
|
||||
@field_validator('password')
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -32,30 +32,32 @@ class UserCreate(UserBase):
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates
|
||||
is_superuser: Optional[bool] = None # Explicitly reject privilege escalation attempts
|
||||
first_name: str | None = None
|
||||
last_name: str | None = None
|
||||
phone_number: str | None = None
|
||||
password: str | None = None
|
||||
preferences: dict[str, Any] | None = None
|
||||
is_active: bool | None = (
|
||||
None # Changed default from True to None to avoid unintended updates
|
||||
)
|
||||
is_superuser: bool | None = None # Explicitly reject privilege escalation attempts
|
||||
|
||||
@field_validator('phone_number')
|
||||
@field_validator("phone_number")
|
||||
@classmethod
|
||||
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_phone(cls, v: str | None) -> str | None:
|
||||
return validate_phone_number(v)
|
||||
|
||||
@field_validator('password')
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: Optional[str]) -> Optional[str]:
|
||||
def password_strength(cls, v: str | None) -> str | None:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
if v is None:
|
||||
return v
|
||||
return validate_password_strength(v)
|
||||
|
||||
@field_validator('is_superuser')
|
||||
@field_validator("is_superuser")
|
||||
@classmethod
|
||||
def prevent_superuser_modification(cls, v: Optional[bool]) -> Optional[bool]:
|
||||
def prevent_superuser_modification(cls, v: bool | None) -> bool | None:
|
||||
"""Prevent users from modifying their superuser status via this schema."""
|
||||
if v is not None:
|
||||
raise ValueError("Cannot modify superuser status through user update")
|
||||
@@ -67,7 +69,7 @@ class UserInDB(UserBase):
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -77,28 +79,28 @@ class UserResponse(UserBase):
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
refresh_token: str | None = None
|
||||
token_type: str = "bearer"
|
||||
user: "UserResponse" # Forward reference since UserResponse is defined above
|
||||
expires_in: Optional[int] = None # Token expiration in seconds
|
||||
expires_in: int | None = None # Token expiration in seconds
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str # User ID
|
||||
exp: int # Expiration time
|
||||
iat: Optional[int] = None # Issued at
|
||||
jti: Optional[str] = None # JWT ID
|
||||
is_superuser: Optional[bool] = False
|
||||
first_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
type: Optional[str] = None # Token type (access/refresh)
|
||||
iat: int | None = None # Issued at
|
||||
jti: str | None = None # JWT ID
|
||||
is_superuser: bool | None = False
|
||||
first_name: str | None = None
|
||||
email: str | None = None
|
||||
type: str | None = None # Token type (access/refresh)
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
@@ -108,10 +110,11 @@ class TokenData(BaseModel):
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
"""Schema for changing password (requires current password)."""
|
||||
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -120,10 +123,11 @@ class PasswordChange(BaseModel):
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
"""Schema for resetting password (via email token)."""
|
||||
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -141,23 +145,19 @@ class RefreshTokenRequest(BaseModel):
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
"""Schema for requesting a password reset."""
|
||||
|
||||
email: EmailStr = Field(..., description="Email address of the account")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"email": "user@example.com"
|
||||
}
|
||||
}
|
||||
}
|
||||
model_config = {"json_schema_extra": {"example": {"email": "user@example.com"}}}
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
"""Schema for confirming a password reset with token."""
|
||||
|
||||
token: str = Field(..., description="Password reset token from email")
|
||||
new_password: str = Field(..., min_length=8, description="New password")
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -167,7 +167,7 @@ class PasswordResetConfirm(BaseModel):
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"token": "eyJwYXlsb2FkIjp7ImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTcxMjM0NTY3OH19",
|
||||
"new_password": "NewSecurePassword123"
|
||||
"new_password": "NewSecurePassword123",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,19 +4,34 @@ Shared validators for Pydantic schemas.
|
||||
This module provides reusable validation functions to ensure consistency
|
||||
across all schemas and avoid code duplication.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
# Common weak passwords that should be rejected
|
||||
COMMON_PASSWORDS: Set[str] = {
|
||||
'password', 'password1', 'password123', 'password1234',
|
||||
'admin', 'admin123', 'admin1234',
|
||||
'welcome', 'welcome1', 'welcome123',
|
||||
'qwerty', 'qwerty123',
|
||||
'12345678', '123456789', '1234567890',
|
||||
'letmein', 'letmein1', 'letmein123',
|
||||
'monkey123', 'dragon123',
|
||||
'passw0rd', 'p@ssw0rd', 'p@ssword',
|
||||
COMMON_PASSWORDS: set[str] = {
|
||||
"password",
|
||||
"password1",
|
||||
"password123",
|
||||
"password1234",
|
||||
"admin",
|
||||
"admin123",
|
||||
"admin1234",
|
||||
"welcome",
|
||||
"welcome1",
|
||||
"welcome123",
|
||||
"qwerty",
|
||||
"qwerty123",
|
||||
"12345678",
|
||||
"123456789",
|
||||
"1234567890",
|
||||
"letmein",
|
||||
"letmein1",
|
||||
"letmein123",
|
||||
"monkey123",
|
||||
"dragon123",
|
||||
"passw0rd",
|
||||
"p@ssw0rd",
|
||||
"p@ssword",
|
||||
}
|
||||
|
||||
|
||||
@@ -47,18 +62,21 @@ def validate_password_strength(password: str) -> str:
|
||||
"""
|
||||
# Check minimum length
|
||||
if len(password) < 12:
|
||||
raise ValueError('Password must be at least 12 characters long')
|
||||
raise ValueError("Password must be at least 12 characters long")
|
||||
|
||||
# Check against common passwords (case-insensitive)
|
||||
if password.lower() in COMMON_PASSWORDS:
|
||||
raise ValueError('Password is too common. Please choose a stronger password')
|
||||
raise ValueError("Password is too common. Please choose a stronger password")
|
||||
|
||||
# Check for required character types
|
||||
checks = [
|
||||
(any(c.islower() for c in password), 'at least one lowercase letter'),
|
||||
(any(c.isupper() for c in password), 'at least one uppercase letter'),
|
||||
(any(c.isdigit() for c in password), 'at least one digit'),
|
||||
(any(c in '!@#$%^&*()_+-=[]{}|;:,.<>?~`' for c in password), 'at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)')
|
||||
(any(c.islower() for c in password), "at least one lowercase letter"),
|
||||
(any(c.isupper() for c in password), "at least one uppercase letter"),
|
||||
(any(c.isdigit() for c in password), "at least one digit"),
|
||||
(
|
||||
any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?~`" for c in password),
|
||||
"at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)",
|
||||
),
|
||||
]
|
||||
|
||||
failed = [msg for check, msg in checks if not check]
|
||||
@@ -94,10 +112,10 @@ def validate_phone_number(phone: str | None) -> str | None:
|
||||
|
||||
# Check for empty strings
|
||||
if not phone or phone.strip() == "":
|
||||
raise ValueError('Phone number cannot be empty')
|
||||
raise ValueError("Phone number cannot be empty")
|
||||
|
||||
# Remove all spaces and formatting characters
|
||||
cleaned = re.sub(r'[\s\-\(\)]', '', phone)
|
||||
cleaned = re.sub(r"[\s\-\(\)]", "", phone)
|
||||
|
||||
# Basic pattern:
|
||||
# Must start with + or 0
|
||||
@@ -105,19 +123,19 @@ def validate_phone_number(phone: str | None) -> str | None:
|
||||
# After 0 must have at least 8 digits
|
||||
# Maximum total length of 15 digits (international standard)
|
||||
# Only allowed characters are + at start and digits
|
||||
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
|
||||
pattern = r"^(?:\+[0-9]{8,14}|0[0-9]{8,14})$"
|
||||
|
||||
if not re.match(pattern, cleaned):
|
||||
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
|
||||
raise ValueError("Phone number must start with + or 0 followed by 8-14 digits")
|
||||
|
||||
# Additional validation to catch specific invalid cases
|
||||
# NOTE: These checks are defensive code - the regex pattern above already catches these cases
|
||||
if cleaned.count('+') > 1: # pragma: no cover
|
||||
raise ValueError('Phone number can only contain one + symbol at the start')
|
||||
if cleaned.count("+") > 1: # pragma: no cover
|
||||
raise ValueError("Phone number can only contain one + symbol at the start")
|
||||
|
||||
# Check for any non-digit characters (except the leading +)
|
||||
if not all(c.isdigit() for c in cleaned[1:]): # pragma: no cover
|
||||
raise ValueError('Phone number can only contain digits after the prefix')
|
||||
raise ValueError("Phone number can only contain digits after the prefix")
|
||||
|
||||
return cleaned
|
||||
|
||||
@@ -169,16 +187,16 @@ def validate_slug(slug: str) -> str:
|
||||
ValueError: If slug format is invalid
|
||||
"""
|
||||
if not slug or len(slug) < 2:
|
||||
raise ValueError('Slug must be at least 2 characters long')
|
||||
raise ValueError("Slug must be at least 2 characters long")
|
||||
|
||||
if len(slug) > 50:
|
||||
raise ValueError('Slug must be at most 50 characters long')
|
||||
raise ValueError("Slug must be at most 50 characters long")
|
||||
|
||||
# Check format
|
||||
if not re.match(r'^[a-z0-9]+(?:-[a-z0-9]+)*$', slug):
|
||||
if not re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", slug):
|
||||
raise ValueError(
|
||||
'Slug can only contain lowercase letters, numbers, and hyphens. '
|
||||
'It cannot start or end with a hyphen, and cannot contain consecutive hyphens'
|
||||
"Slug can only contain lowercase letters, numbers, and hyphens. "
|
||||
"It cannot start or end with a hyphen, and cannot contain consecutive hyphens"
|
||||
)
|
||||
|
||||
return slug
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
# app/services/auth_service.py
|
||||
import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password_async,
|
||||
get_password_hash_async,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError
|
||||
get_password_hash_async,
|
||||
verify_password_async,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError
|
||||
@@ -26,7 +25,9 @@ class AuthService:
|
||||
"""Service for handling authentication operations"""
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
|
||||
async def authenticate_user(
|
||||
db: AsyncSession, email: str, password: str
|
||||
) -> User | None:
|
||||
"""
|
||||
Authenticate a user with email and password using async password verification.
|
||||
|
||||
@@ -87,7 +88,7 @@ class AuthService:
|
||||
last_name=user_data.last_name,
|
||||
phone_number=user_data.phone_number,
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
@@ -103,8 +104,8 @@ class AuthService:
|
||||
except Exception as e:
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating user: {str(e)}", exc_info=True)
|
||||
raise AuthenticationError(f"Failed to create user: {str(e)}")
|
||||
logger.error(f"Error creating user: {e!s}", exc_info=True)
|
||||
raise AuthenticationError(f"Failed to create user: {e!s}")
|
||||
|
||||
@staticmethod
|
||||
def create_tokens(user: User) -> Token:
|
||||
@@ -121,18 +122,13 @@ class AuthService:
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name
|
||||
"first_name": user.first_name,
|
||||
}
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(
|
||||
subject=str(user.id),
|
||||
claims=claims
|
||||
)
|
||||
access_token = create_access_token(subject=str(user.id), claims=claims)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
subject=str(user.id)
|
||||
)
|
||||
refresh_token = create_refresh_token(subject=str(user.id))
|
||||
|
||||
# Convert User model to UserResponse schema
|
||||
user_response = UserResponse.model_validate(user)
|
||||
@@ -141,7 +137,8 @@ class AuthService:
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=user_response,
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # Convert minutes to seconds
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
* 60, # Convert minutes to seconds
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -180,11 +177,13 @@ class AuthService:
|
||||
return AuthService.create_tokens(user)
|
||||
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
logger.warning(f"Token refresh failed: {str(e)}")
|
||||
logger.warning(f"Token refresh failed: {e!s}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||
async def change_password(
|
||||
db: AsyncSession, user_id: UUID, current_password: str, new_password: str
|
||||
) -> bool:
|
||||
"""
|
||||
Change a user's password.
|
||||
|
||||
@@ -223,5 +222,7 @@ class AuthService:
|
||||
except Exception as e:
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True)
|
||||
raise AuthenticationError(f"Failed to change password: {str(e)}")
|
||||
logger.error(
|
||||
f"Error changing password for user {user_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise AuthenticationError(f"Failed to change password: {e!s}")
|
||||
|
||||
@@ -5,9 +5,9 @@ Email service with placeholder implementation.
|
||||
This service provides email sending functionality with a simple console/log-based
|
||||
placeholder that can be easily replaced with a real email provider (SendGrid, SES, etc.)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -20,13 +20,12 @@ class EmailBackend(ABC):
|
||||
@abstractmethod
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""Send an email."""
|
||||
pass
|
||||
|
||||
|
||||
class ConsoleEmailBackend(EmailBackend):
|
||||
@@ -39,10 +38,10 @@ class ConsoleEmailBackend(EmailBackend):
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Log email content to console/logs.
|
||||
@@ -88,10 +87,10 @@ class SMTPEmailBackend(EmailBackend):
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""Send email via SMTP."""
|
||||
# TODO: Implement SMTP sending
|
||||
@@ -108,7 +107,7 @@ class EmailService:
|
||||
and can be configured to use different backends (console, SMTP, SendGrid, etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, backend: Optional[EmailBackend] = None):
|
||||
def __init__(self, backend: EmailBackend | None = None):
|
||||
"""
|
||||
Initialize email service with a backend.
|
||||
|
||||
@@ -118,10 +117,7 @@ class EmailService:
|
||||
self.backend = backend or ConsoleEmailBackend()
|
||||
|
||||
async def send_password_reset_email(
|
||||
self,
|
||||
to_email: str,
|
||||
reset_token: str,
|
||||
user_name: Optional[str] = None
|
||||
self, to_email: str, reset_token: str, user_name: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send password reset email.
|
||||
@@ -142,7 +138,7 @@ class EmailService:
|
||||
|
||||
# Plain text version
|
||||
text_content = f"""
|
||||
Hello{' ' + user_name if user_name else ''},
|
||||
Hello{" " + user_name if user_name else ""},
|
||||
|
||||
You requested a password reset for your account. Click the link below to reset your password:
|
||||
|
||||
@@ -177,7 +173,7 @@ The {settings.PROJECT_NAME} Team
|
||||
<h1>Password Reset</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hello{' ' + user_name if user_name else ''},</p>
|
||||
<p>Hello{" " + user_name if user_name else ""},</p>
|
||||
<p>You requested a password reset for your account. Click the button below to reset your password:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{reset_url}" class="button">Reset Password</a>
|
||||
@@ -200,17 +196,14 @@ The {settings.PROJECT_NAME} Team
|
||||
to=[to_email],
|
||||
subject=subject,
|
||||
html_content=html_content,
|
||||
text_content=text_content
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password reset email to {to_email}: {str(e)}")
|
||||
logger.error(f"Failed to send password reset email to {to_email}: {e!s}")
|
||||
return False
|
||||
|
||||
async def send_email_verification(
|
||||
self,
|
||||
to_email: str,
|
||||
verification_token: str,
|
||||
user_name: Optional[str] = None
|
||||
self, to_email: str, verification_token: str, user_name: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send email verification email.
|
||||
@@ -224,14 +217,16 @@ The {settings.PROJECT_NAME} Team
|
||||
True if email sent successfully
|
||||
"""
|
||||
# Generate verification URL
|
||||
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
|
||||
verification_url = (
|
||||
f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
|
||||
)
|
||||
|
||||
# Prepare email content
|
||||
subject = "Verify Your Email Address"
|
||||
|
||||
# Plain text version
|
||||
text_content = f"""
|
||||
Hello{' ' + user_name if user_name else ''},
|
||||
Hello{" " + user_name if user_name else ""},
|
||||
|
||||
Thank you for signing up! Please verify your email address by clicking the link below:
|
||||
|
||||
@@ -266,7 +261,7 @@ The {settings.PROJECT_NAME} Team
|
||||
<h1>Verify Your Email</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hello{' ' + user_name if user_name else ''},</p>
|
||||
<p>Hello{" " + user_name if user_name else ""},</p>
|
||||
<p>Thank you for signing up! Please verify your email address by clicking the button below:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{verification_url}" class="button">Verify Email</a>
|
||||
@@ -289,10 +284,10 @@ The {settings.PROJECT_NAME} Team
|
||||
to=[to_email],
|
||||
subject=subject,
|
||||
html_content=html_content,
|
||||
text_content=text_content
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send verification email to {to_email}: {str(e)}")
|
||||
logger.error(f"Failed to send verification email to {to_email}: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,9 @@ Background job for cleaning up expired sessions.
|
||||
|
||||
This service runs periodically to remove old session records from the database.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.crud.session import session as session_crud
|
||||
@@ -39,7 +40,7 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error during session cleanup: {e!s}", exc_info=True)
|
||||
return 0
|
||||
|
||||
|
||||
@@ -52,20 +53,21 @@ async def get_session_statistics() -> dict:
|
||||
"""
|
||||
async with SessionLocal() as db:
|
||||
try:
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
total_result = await db.execute(select(func.count(UserSession.id)))
|
||||
total_sessions = total_result.scalar_one()
|
||||
|
||||
active_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(UserSession.is_active == True)
|
||||
select(func.count(UserSession.id)).where(UserSession.is_active)
|
||||
)
|
||||
active_sessions = active_result.scalar_one()
|
||||
|
||||
expired_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
UserSession.expires_at < datetime.now(timezone.utc)
|
||||
UserSession.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
expired_sessions = expired_result.scalar_one()
|
||||
@@ -82,5 +84,5 @@ async def get_session_statistics() -> dict:
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
|
||||
logger.error(f"Error getting session statistics: {e!s}", exc_info=True)
|
||||
return {}
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
Authentication utilities for testing.
|
||||
This module provides tools to bypass FastAPI's authentication in tests.
|
||||
"""
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
@@ -13,9 +14,9 @@ from app.models.user import User
|
||||
|
||||
|
||||
def create_test_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User,
|
||||
extra_overrides: Optional[Dict[Callable, Callable]] = None
|
||||
app: FastAPI,
|
||||
test_user: User,
|
||||
extra_overrides: dict[Callable, Callable] | None = None,
|
||||
) -> TestClient:
|
||||
"""
|
||||
Create a test client with authentication pre-configured.
|
||||
@@ -47,10 +48,7 @@ def create_test_auth_client(
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_optional_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
def create_test_optional_auth_client(app: FastAPI, test_user: User) -> TestClient:
|
||||
"""
|
||||
Create a test client with optional authentication pre-configured.
|
||||
|
||||
@@ -70,10 +68,7 @@ def create_test_optional_auth_client(
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_superuser_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
def create_test_superuser_client(app: FastAPI, test_user: User) -> TestClient:
|
||||
"""
|
||||
Create a test client with superuser authentication pre-configured.
|
||||
|
||||
@@ -120,7 +115,7 @@ def cleanup_test_client_auth(app: FastAPI) -> None:
|
||||
auth_deps = [
|
||||
get_current_user,
|
||||
get_optional_current_user,
|
||||
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login"),
|
||||
]
|
||||
|
||||
# Remove overrides
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Utility functions for extracting and parsing device information from HTTP requests.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
@@ -19,11 +19,11 @@ def extract_device_info(request: Request) -> DeviceInfo:
|
||||
Returns:
|
||||
DeviceInfo object with parsed device information
|
||||
"""
|
||||
user_agent = request.headers.get('user-agent', '')
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
|
||||
device_info = DeviceInfo(
|
||||
device_name=parse_device_name(user_agent),
|
||||
device_id=request.headers.get('x-device-id'), # Client must send this header
|
||||
device_id=request.headers.get("x-device-id"), # Client must send this header
|
||||
ip_address=get_client_ip(request),
|
||||
user_agent=user_agent[:500] if user_agent else None, # Truncate to max length
|
||||
location_city=None, # Can be populated via IP geolocation service
|
||||
@@ -33,7 +33,7 @@ def extract_device_info(request: Request) -> DeviceInfo:
|
||||
return device_info
|
||||
|
||||
|
||||
def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
def parse_device_name(user_agent: str) -> str | None:
|
||||
"""
|
||||
Parse user agent string to extract a friendly device name.
|
||||
|
||||
@@ -54,48 +54,48 @@ def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Mobile devices (check first, as they can contain desktop patterns too)
|
||||
if 'iphone' in user_agent_lower:
|
||||
if "iphone" in user_agent_lower:
|
||||
return "iPhone"
|
||||
elif 'ipad' in user_agent_lower:
|
||||
elif "ipad" in user_agent_lower:
|
||||
return "iPad"
|
||||
elif 'android' in user_agent_lower:
|
||||
elif "android" in user_agent_lower:
|
||||
# Try to extract device model
|
||||
android_match = re.search(r'android.*;\s*([^)]+)\s*build', user_agent_lower)
|
||||
android_match = re.search(r"android.*;\s*([^)]+)\s*build", user_agent_lower)
|
||||
if android_match:
|
||||
device_model = android_match.group(1).strip()
|
||||
return f"Android ({device_model.title()})"
|
||||
return "Android device"
|
||||
elif 'windows phone' in user_agent_lower:
|
||||
elif "windows phone" in user_agent_lower:
|
||||
return "Windows Phone"
|
||||
|
||||
# Tablets (check before desktop, as some tablets contain "android")
|
||||
elif 'tablet' in user_agent_lower:
|
||||
elif "tablet" in user_agent_lower:
|
||||
return "Tablet"
|
||||
|
||||
# Smart TVs (check before desktop OS patterns)
|
||||
elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv']):
|
||||
elif any(tv in user_agent_lower for tv in ["smart-tv", "smarttv"]):
|
||||
return "Smart TV"
|
||||
|
||||
# Game consoles (check before desktop OS patterns, as Xbox contains "Windows")
|
||||
elif 'playstation' in user_agent_lower:
|
||||
elif "playstation" in user_agent_lower:
|
||||
return "PlayStation"
|
||||
elif 'xbox' in user_agent_lower:
|
||||
elif "xbox" in user_agent_lower:
|
||||
return "Xbox"
|
||||
elif 'nintendo' in user_agent_lower:
|
||||
elif "nintendo" in user_agent_lower:
|
||||
return "Nintendo"
|
||||
|
||||
# Desktop operating systems
|
||||
elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower:
|
||||
elif "macintosh" in user_agent_lower or "mac os x" in user_agent_lower:
|
||||
# Try to extract browser
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Mac" if browser else "Mac"
|
||||
elif 'windows' in user_agent_lower:
|
||||
elif "windows" in user_agent_lower:
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Windows" if browser else "Windows PC"
|
||||
elif 'linux' in user_agent_lower and 'android' not in user_agent_lower:
|
||||
elif "linux" in user_agent_lower and "android" not in user_agent_lower:
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Linux" if browser else "Linux"
|
||||
elif 'cros' in user_agent_lower:
|
||||
elif "cros" in user_agent_lower:
|
||||
return "Chromebook"
|
||||
|
||||
# Fallback: just return browser name if detected
|
||||
@@ -106,7 +106,7 @@ def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
return "Unknown device"
|
||||
|
||||
|
||||
def extract_browser(user_agent: str) -> Optional[str]:
|
||||
def extract_browser(user_agent: str) -> str | None:
|
||||
"""
|
||||
Extract browser name from user agent string.
|
||||
|
||||
@@ -126,26 +126,26 @@ def extract_browser(user_agent: str) -> Optional[str]:
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Check specific browsers (order matters - check Edge before Chrome!)
|
||||
if 'edg/' in user_agent_lower or 'edge/' in user_agent_lower:
|
||||
if "edg/" in user_agent_lower or "edge/" in user_agent_lower:
|
||||
return "Edge"
|
||||
elif 'opr/' in user_agent_lower or 'opera' in user_agent_lower:
|
||||
elif "opr/" in user_agent_lower or "opera" in user_agent_lower:
|
||||
return "Opera"
|
||||
elif 'chrome/' in user_agent_lower:
|
||||
elif "chrome/" in user_agent_lower:
|
||||
return "Chrome"
|
||||
elif 'safari/' in user_agent_lower:
|
||||
elif "safari/" in user_agent_lower:
|
||||
# Make sure it's actually Safari, not Chrome (which also contains "Safari")
|
||||
if 'chrome' not in user_agent_lower:
|
||||
if "chrome" not in user_agent_lower:
|
||||
return "Safari"
|
||||
return None
|
||||
elif 'firefox/' in user_agent_lower:
|
||||
elif "firefox/" in user_agent_lower:
|
||||
return "Firefox"
|
||||
elif 'msie' in user_agent_lower or 'trident/' in user_agent_lower:
|
||||
elif "msie" in user_agent_lower or "trident/" in user_agent_lower:
|
||||
return "Internet Explorer"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> Optional[str]:
|
||||
def get_client_ip(request: Request) -> str | None:
|
||||
"""
|
||||
Extract client IP address from request, considering proxy headers.
|
||||
|
||||
@@ -163,14 +163,14 @@ def get_client_ip(request: Request) -> Optional[str]:
|
||||
- request.client.host is fallback for direct connections
|
||||
"""
|
||||
# Check X-Forwarded-For (common in proxied environments)
|
||||
x_forwarded_for = request.headers.get('x-forwarded-for')
|
||||
x_forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if x_forwarded_for:
|
||||
# Get the first IP (original client)
|
||||
client_ip = x_forwarded_for.split(',')[0].strip()
|
||||
client_ip = x_forwarded_for.split(",")[0].strip()
|
||||
return client_ip
|
||||
|
||||
# Check X-Real-IP (used by some proxies like nginx)
|
||||
x_real_ip = request.headers.get('x-real-ip')
|
||||
x_real_ip = request.headers.get("x-real-ip")
|
||||
if x_real_ip:
|
||||
return x_real_ip.strip()
|
||||
|
||||
@@ -195,9 +195,17 @@ def is_mobile_device(user_agent: str) -> bool:
|
||||
return False
|
||||
|
||||
mobile_patterns = [
|
||||
'mobile', 'android', 'iphone', 'ipad', 'ipod',
|
||||
'blackberry', 'windows phone', 'webos', 'opera mini',
|
||||
'iemobile', 'mobile safari'
|
||||
"mobile",
|
||||
"android",
|
||||
"iphone",
|
||||
"ipad",
|
||||
"ipod",
|
||||
"blackberry",
|
||||
"windows phone",
|
||||
"webos",
|
||||
"opera mini",
|
||||
"iemobile",
|
||||
"mobile safari",
|
||||
]
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
@@ -220,7 +228,7 @@ def get_device_type(user_agent: str) -> str:
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Check for tablets first (they can contain "mobile" too)
|
||||
if 'ipad' in user_agent_lower or 'tablet' in user_agent_lower:
|
||||
if "ipad" in user_agent_lower or "tablet" in user_agent_lower:
|
||||
return "tablet"
|
||||
|
||||
# Check for mobile
|
||||
@@ -228,7 +236,7 @@ def get_device_type(user_agent: str) -> str:
|
||||
return "mobile"
|
||||
|
||||
# Check for desktop OS patterns
|
||||
if any(os in user_agent_lower for os in ['windows', 'macintosh', 'linux', 'cros']):
|
||||
if any(os in user_agent_lower for os in ["windows", "macintosh", "linux", "cros"]):
|
||||
return "desktop"
|
||||
|
||||
return "other"
|
||||
|
||||
@@ -5,18 +5,21 @@ This module provides utilities for creating and verifying signed tokens,
|
||||
useful for operations like file uploads, password resets, or any other
|
||||
time-limited, single-use operations.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def create_upload_token(file_path: str, content_type: str, expires_in: int = 300) -> str:
|
||||
def create_upload_token(
|
||||
file_path: str, content_type: str, expires_in: int = 300
|
||||
) -> str:
|
||||
"""
|
||||
Create a signed token for secure file uploads.
|
||||
|
||||
@@ -40,34 +43,29 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300
|
||||
"path": file_path,
|
||||
"content_type": content_type,
|
||||
"exp": int(time.time()) + expires_in,
|
||||
"nonce": secrets.token_hex(8) # Add randomness to prevent token reuse
|
||||
"nonce": secrets.token_hex(8), # Add randomness to prevent token reuse
|
||||
}
|
||||
|
||||
# Convert to JSON and encode
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Create a signature using HMAC-SHA256 for security
|
||||
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||
signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Combine payload and signature
|
||||
token_data = {
|
||||
"payload": payload,
|
||||
"signature": signature
|
||||
}
|
||||
token_data = {"payload": payload, "signature": signature}
|
||||
|
||||
# Encode the final token
|
||||
token_json = json.dumps(token_data)
|
||||
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
|
||||
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
def verify_upload_token(token: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Verify an upload token and return the payload if valid.
|
||||
|
||||
@@ -88,7 +86,7 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
try:
|
||||
# Decode the token
|
||||
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
# Extract payload and signature
|
||||
@@ -96,11 +94,9 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
signature = token_data["signature"]
|
||||
|
||||
# Verify signature using HMAC and constant-time comparison
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
@@ -136,34 +132,29 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str:
|
||||
"email": email,
|
||||
"exp": int(time.time()) + expires_in,
|
||||
"nonce": secrets.token_hex(16), # Extra randomness
|
||||
"purpose": "password_reset"
|
||||
"purpose": "password_reset",
|
||||
}
|
||||
|
||||
# Convert to JSON and encode
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Create a signature using HMAC-SHA256 for security
|
||||
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||
signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Combine payload and signature
|
||||
token_data = {
|
||||
"payload": payload,
|
||||
"signature": signature
|
||||
}
|
||||
token_data = {"payload": payload, "signature": signature}
|
||||
|
||||
# Encode the final token
|
||||
token_json = json.dumps(token_data)
|
||||
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
|
||||
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
def verify_password_reset_token(token: str) -> str | None:
|
||||
"""
|
||||
Verify a password reset token and return the email if valid.
|
||||
|
||||
@@ -182,7 +173,7 @@ def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
try:
|
||||
# Decode the token
|
||||
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
# Extract payload and signature
|
||||
@@ -194,11 +185,9 @@ def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
# Verify signature using HMAC and constant-time comparison
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
@@ -234,34 +223,29 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str:
|
||||
"email": email,
|
||||
"exp": int(time.time()) + expires_in,
|
||||
"nonce": secrets.token_hex(16),
|
||||
"purpose": "email_verification"
|
||||
"purpose": "email_verification",
|
||||
}
|
||||
|
||||
# Convert to JSON and encode
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Create a signature using HMAC-SHA256 for security
|
||||
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||
signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Combine payload and signature
|
||||
token_data = {
|
||||
"payload": payload,
|
||||
"signature": signature
|
||||
}
|
||||
token_data = {"payload": payload, "signature": signature}
|
||||
|
||||
# Encode the final token
|
||||
token_json = json.dumps(token_data)
|
||||
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
|
||||
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def verify_email_verification_token(token: str) -> Optional[str]:
|
||||
def verify_email_verification_token(token: str) -> str | None:
|
||||
"""
|
||||
Verify an email verification token and return the email if valid.
|
||||
|
||||
@@ -280,7 +264,7 @@ def verify_email_verification_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
try:
|
||||
# Decode the token
|
||||
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
# Extract payload and signature
|
||||
@@ -292,11 +276,9 @@ def verify_email_verification_token(token: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
# Verify signature using HMAC and constant-time comparison
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
|
||||
@@ -9,17 +9,19 @@ from app.core.database import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_test_engine():
|
||||
"""Create an SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
echo=False,
|
||||
)
|
||||
|
||||
return test_engine
|
||||
|
||||
|
||||
def setup_test_db():
|
||||
"""Create a test database and session factory"""
|
||||
# Create a new engine for this test run
|
||||
@@ -30,14 +32,12 @@ def setup_test_db():
|
||||
|
||||
# Create session factory
|
||||
TestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False
|
||||
autocommit=False, autoflush=False, bind=test_engine, expire_on_commit=False
|
||||
)
|
||||
|
||||
return test_engine, TestingSessionLocal
|
||||
|
||||
|
||||
def teardown_test_db(engine):
|
||||
"""Clean up after tests"""
|
||||
# Drop all tables
|
||||
@@ -46,13 +46,14 @@ def teardown_test_db(engine):
|
||||
# Dispose of engine
|
||||
engine.dispose()
|
||||
|
||||
|
||||
async def get_async_test_engine():
|
||||
"""Create an async SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
echo=False,
|
||||
)
|
||||
return test_engine
|
||||
|
||||
@@ -69,7 +70,7 @@ async def setup_async_test_db():
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession
|
||||
class_=AsyncSession,
|
||||
)
|
||||
|
||||
return test_engine, AsyncTestingSessionLocal
|
||||
|
||||
Reference in New Issue
Block a user