Compare commits

...

4 Commits

Author SHA1 Message Date
Felipe Cardoso
80d2dc0cb2 fix(backend): clear VIRTUAL_ENV before invoking pyright
Prevents a spurious warning when the shell's VIRTUAL_ENV points to a
different project's venv. Pyright detects the mismatch and warns; clearing
the variable inline forces pyright to resolve the venv from pyrightconfig.json.
2026-02-28 19:48:33 +01:00
Felipe Cardoso
a8aa416ecb refactor(backend): migrate type checking from mypy to pyright
Replace mypy>=1.8.0 with pyright>=1.1.390. Remove all [tool.mypy] and
[tool.pydantic-mypy] sections from pyproject.toml and add
pyrightconfig.json (standard mode, SQLAlchemy false-positive rules
suppressed globally).

Fixes surfaced by pyright:
- Remove unreachable except AuthError clauses in login/login_oauth (same class as AuthenticationError)
- Fix Pydantic v2 list Field: min_items/max_items → min_length/max_length
- Split OAuthProviderConfig TypedDict into required + optional(email_url) inheritance
- Move JWTError/ExpiredSignatureError from lazy try-block imports to module level
- Add timezone-aware guard to UserSession.is_expired to match sibling models
- Fix is_active: bool → bool | None in three organization repo signatures
- Initialize search_filter = None before conditional block (possibly unbound fix)
- Add bool() casts to model is_expired and repo is_active/is_superuser returns
- Restructure except (JWTError, Exception) into separate except clauses
2026-02-28 19:12:40 +01:00
Felipe Cardoso
4c6bf55bcc Refactor(backend): improve formatting in services, repositories & tests
- Consistently format multi-line function headers, exception handling, and repository method calls for readability.
- Reorganize misplaced imports across modules (e.g., services & tests) into proper sorted order.
- Adjust indentation, line breaks, and spacing inconsistencies in tests and migration files.
- Cleanup unnecessary trailing newlines and reorganize `__all__` declarations for consistency.
2026-02-28 18:37:56 +01:00
Felipe Cardoso
98b455fdc3 refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError,
  IntegrityConstraintError, InvalidInputError) replacing raw ValueError
- eliminate all direct repository imports and raw SQL from route layer
- add UserService, SessionService, OrganizationService to service layer
- add get_stats/get_org_distribution service methods replacing admin inline SQL
- fix timing side-channel in authenticate_user via dummy bcrypt check
- replace SHA-256 client secret fallback with explicit InvalidClientError
- replace assert with InvalidGrantError in authorization code exchange
- replace N+1 token revocation loops with bulk UPDATE statements
- rename oauth account token fields (drop misleading 'encrypted' suffix)
- add Alembic migration 0003 for token field column rename
- add 45 new service/repository tests; 975 passing, 94% coverage
2026-02-27 09:32:57 +01:00
75 changed files with 3488 additions and 2166 deletions

View File

@@ -1,5 +1,8 @@
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
# Prevent a stale VIRTUAL_ENV in the caller's shell from confusing uv
unexport VIRTUAL_ENV
# Default target
help:
@echo "🚀 FastAPI Backend - Development Commands"
@@ -14,7 +17,7 @@ help:
@echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff"
@echo " make format-check - Check if code is formatted"
@echo " make type-check - Run mypy type checking"
@echo " make type-check - Run pyright type checking"
@echo " make validate - Run all checks (lint + format + types)"
@echo ""
@echo "Testing:"
@@ -63,8 +66,8 @@ format-check:
@uv run ruff format --check app/ tests/
type-check:
@echo "🔎 Running mypy type checking..."
@uv run mypy app/
@echo "🔎 Running pyright type checking..."
@uv run pyright app/
validate: lint format-check type-check
@echo "✅ All quality checks passed!"
@@ -127,7 +130,7 @@ clean:
@echo "🧹 Cleaning up..."
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".pyright" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true

View File

@@ -40,6 +40,7 @@ def include_object(object, name, type_, reflected, compare_to):
return False
return True
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:

View File

@@ -1,262 +1,446 @@
"""initial models
Revision ID: 0001
Revises:
Revises:
Create Date: 2025-11-27 09:08:09.464506
"""
from typing import Sequence, Union
from alembic import op
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '0001'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
revision: str = "0001"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('oauth_states',
sa.Column('state', sa.String(length=255), nullable=False),
sa.Column('code_verifier', sa.String(length=128), nullable=True),
sa.Column('nonce', sa.String(length=255), nullable=True),
sa.Column('provider', sa.String(length=50), nullable=False),
sa.Column('redirect_uri', sa.String(length=500), nullable=True),
sa.Column('user_id', sa.UUID(), nullable=True),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
op.create_table(
"oauth_states",
sa.Column("state", sa.String(length=255), nullable=False),
sa.Column("code_verifier", sa.String(length=128), nullable=True),
sa.Column("nonce", sa.String(length=255), nullable=True),
sa.Column("provider", sa.String(length=50), nullable=False),
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f('ix_oauth_states_state'), 'oauth_states', ['state'], unique=True)
op.create_table('organizations',
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('slug', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('settings', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
op.create_index(
op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
)
op.create_index(op.f('ix_organizations_is_active'), 'organizations', ['is_active'], unique=False)
op.create_index(op.f('ix_organizations_name'), 'organizations', ['name'], unique=False)
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'], unique=False)
op.create_index(op.f('ix_organizations_slug'), 'organizations', ['slug'], unique=True)
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'], unique=False)
op.create_table('users',
sa.Column('email', sa.String(length=255), nullable=False),
sa.Column('password_hash', sa.String(length=255), nullable=True),
sa.Column('first_name', sa.String(length=100), nullable=False),
sa.Column('last_name', sa.String(length=100), nullable=True),
sa.Column('phone_number', sa.String(length=20), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('is_superuser', sa.Boolean(), nullable=False),
sa.Column('preferences', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('locale', sa.String(length=10), nullable=True),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
op.create_table(
"organizations",
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("slug", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f('ix_users_deleted_at'), 'users', ['deleted_at'], unique=False)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
op.create_index(op.f('ix_users_locale'), 'users', ['locale'], unique=False)
op.create_table('oauth_accounts',
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('provider', sa.String(length=50), nullable=False),
sa.Column('provider_user_id', sa.String(length=255), nullable=False),
sa.Column('provider_email', sa.String(length=255), nullable=True),
sa.Column('access_token_encrypted', sa.String(length=2048), nullable=True),
sa.Column('refresh_token_encrypted', sa.String(length=2048), nullable=True),
sa.Column('token_expires_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('provider', 'provider_user_id', name='uq_oauth_provider_user')
op.create_index(
op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
)
op.create_index(op.f('ix_oauth_accounts_provider'), 'oauth_accounts', ['provider'], unique=False)
op.create_index(op.f('ix_oauth_accounts_provider_email'), 'oauth_accounts', ['provider_email'], unique=False)
op.create_index(op.f('ix_oauth_accounts_user_id'), 'oauth_accounts', ['user_id'], unique=False)
op.create_index('ix_oauth_accounts_user_provider', 'oauth_accounts', ['user_id', 'provider'], unique=False)
op.create_table('oauth_clients',
sa.Column('client_id', sa.String(length=64), nullable=False),
sa.Column('client_secret_hash', sa.String(length=255), nullable=True),
sa.Column('client_name', sa.String(length=255), nullable=False),
sa.Column('client_description', sa.String(length=1000), nullable=True),
sa.Column('client_type', sa.String(length=20), nullable=False),
sa.Column('redirect_uris', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('allowed_scopes', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('access_token_lifetime', sa.String(length=10), nullable=False),
sa.Column('refresh_token_lifetime', sa.String(length=10), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('owner_user_id', sa.UUID(), nullable=True),
sa.Column('mcp_server_url', sa.String(length=2048), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ondelete='SET NULL'),
sa.PrimaryKeyConstraint('id')
op.create_index(
op.f("ix_organizations_name"), "organizations", ["name"], unique=False
)
op.create_index(op.f('ix_oauth_clients_client_id'), 'oauth_clients', ['client_id'], unique=True)
op.create_index(op.f('ix_oauth_clients_is_active'), 'oauth_clients', ['is_active'], unique=False)
op.create_table('user_organizations',
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('organization_id', sa.UUID(), nullable=False),
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('user_id', 'organization_id')
op.create_index(
"ix_organizations_name_active",
"organizations",
["name", "is_active"],
unique=False,
)
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'], unique=False)
op.create_index('ix_user_org_role', 'user_organizations', ['role'], unique=False)
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'], unique=False)
op.create_index(op.f('ix_user_organizations_is_active'), 'user_organizations', ['is_active'], unique=False)
op.create_table('user_sessions',
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
sa.Column('device_name', sa.String(length=255), nullable=True),
sa.Column('device_id', sa.String(length=255), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.String(length=500), nullable=True),
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('location_city', sa.String(length=100), nullable=True),
sa.Column('location_country', sa.String(length=100), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
op.create_index(
op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
)
op.create_index(op.f('ix_user_sessions_is_active'), 'user_sessions', ['is_active'], unique=False)
op.create_index('ix_user_sessions_jti_active', 'user_sessions', ['refresh_token_jti', 'is_active'], unique=False)
op.create_index(op.f('ix_user_sessions_refresh_token_jti'), 'user_sessions', ['refresh_token_jti'], unique=True)
op.create_index('ix_user_sessions_user_active', 'user_sessions', ['user_id', 'is_active'], unique=False)
op.create_index(op.f('ix_user_sessions_user_id'), 'user_sessions', ['user_id'], unique=False)
op.create_table('oauth_authorization_codes',
sa.Column('code', sa.String(length=128), nullable=False),
sa.Column('client_id', sa.String(length=64), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('redirect_uri', sa.String(length=2048), nullable=False),
sa.Column('scope', sa.String(length=1000), nullable=False),
sa.Column('code_challenge', sa.String(length=128), nullable=True),
sa.Column('code_challenge_method', sa.String(length=10), nullable=True),
sa.Column('state', sa.String(length=256), nullable=True),
sa.Column('nonce', sa.String(length=256), nullable=True),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('used', sa.Boolean(), nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
op.create_index(
"ix_organizations_slug_active",
"organizations",
["slug", "is_active"],
unique=False,
)
op.create_index('ix_oauth_authorization_codes_client_user', 'oauth_authorization_codes', ['client_id', 'user_id'], unique=False)
op.create_index(op.f('ix_oauth_authorization_codes_code'), 'oauth_authorization_codes', ['code'], unique=True)
op.create_index('ix_oauth_authorization_codes_expires_at', 'oauth_authorization_codes', ['expires_at'], unique=False)
op.create_table('oauth_consents',
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('client_id', sa.String(length=64), nullable=False),
sa.Column('granted_scopes', sa.String(length=1000), nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
op.create_table(
"users",
sa.Column("email", sa.String(length=255), nullable=False),
sa.Column("password_hash", sa.String(length=255), nullable=True),
sa.Column("first_name", sa.String(length=100), nullable=False),
sa.Column("last_name", sa.String(length=100), nullable=True),
sa.Column("phone_number", sa.String(length=20), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column(
"preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.Column("locale", sa.String(length=10), nullable=True),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index('ix_oauth_consents_user_client', 'oauth_consents', ['user_id', 'client_id'], unique=True)
op.create_table('oauth_provider_refresh_tokens',
sa.Column('token_hash', sa.String(length=64), nullable=False),
sa.Column('jti', sa.String(length=64), nullable=False),
sa.Column('client_id', sa.String(length=64), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('scope', sa.String(length=1000), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('revoked', sa.Boolean(), nullable=False),
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('device_info', sa.String(length=500), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
op.create_index(
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
)
op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
op.create_table(
"oauth_accounts",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("provider", sa.String(length=50), nullable=False),
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
sa.Column("provider_email", sa.String(length=255), nullable=True),
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"provider", "provider_user_id", name="uq_oauth_provider_user"
),
)
op.create_index(
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
)
op.create_index(
op.f("ix_oauth_accounts_provider_email"),
"oauth_accounts",
["provider_email"],
unique=False,
)
op.create_index(
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
)
op.create_index(
"ix_oauth_accounts_user_provider",
"oauth_accounts",
["user_id", "provider"],
unique=False,
)
op.create_table(
"oauth_clients",
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
sa.Column("client_name", sa.String(length=255), nullable=False),
sa.Column("client_description", sa.String(length=1000), nullable=True),
sa.Column("client_type", sa.String(length=20), nullable=False),
sa.Column(
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column(
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("owner_user_id", sa.UUID(), nullable=True),
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
)
op.create_index(
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
)
op.create_table(
"user_organizations",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("organization_id", sa.UUID(), nullable=False),
sa.Column(
"role",
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
nullable=False,
),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("user_id", "organization_id"),
)
op.create_index(
"ix_user_org_org_active",
"user_organizations",
["organization_id", "is_active"],
unique=False,
)
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
op.create_index(
"ix_user_org_user_active",
"user_organizations",
["user_id", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_user_organizations_is_active"),
"user_organizations",
["is_active"],
unique=False,
)
op.create_table(
"user_sessions",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
sa.Column("device_name", sa.String(length=255), nullable=True),
sa.Column("device_id", sa.String(length=255), nullable=True),
sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column("user_agent", sa.String(length=500), nullable=True),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("location_city", sa.String(length=100), nullable=True),
sa.Column("location_country", sa.String(length=100), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
)
op.create_index(
"ix_user_sessions_jti_active",
"user_sessions",
["refresh_token_jti", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_user_sessions_refresh_token_jti"),
"user_sessions",
["refresh_token_jti"],
unique=True,
)
op.create_index(
"ix_user_sessions_user_active",
"user_sessions",
["user_id", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
)
op.create_table(
"oauth_authorization_codes",
sa.Column("code", sa.String(length=128), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
sa.Column("scope", sa.String(length=1000), nullable=False),
sa.Column("code_challenge", sa.String(length=128), nullable=True),
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
sa.Column("state", sa.String(length=256), nullable=True),
sa.Column("nonce", sa.String(length=256), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("used", sa.Boolean(), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_authorization_codes_client_user",
"oauth_authorization_codes",
["client_id", "user_id"],
unique=False,
)
op.create_index(
op.f("ix_oauth_authorization_codes_code"),
"oauth_authorization_codes",
["code"],
unique=True,
)
op.create_index(
"ix_oauth_authorization_codes_expires_at",
"oauth_authorization_codes",
["expires_at"],
unique=False,
)
op.create_table(
"oauth_consents",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_consents_user_client",
"oauth_consents",
["user_id", "client_id"],
unique=True,
)
op.create_table(
"oauth_provider_refresh_tokens",
sa.Column("token_hash", sa.String(length=64), nullable=False),
sa.Column("jti", sa.String(length=64), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("scope", sa.String(length=1000), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("device_info", sa.String(length=500), nullable=True),
sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_provider_refresh_tokens_client_user",
"oauth_provider_refresh_tokens",
["client_id", "user_id"],
unique=False,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_expires_at",
"oauth_provider_refresh_tokens",
["expires_at"],
unique=False,
)
op.create_index(
op.f("ix_oauth_provider_refresh_tokens_jti"),
"oauth_provider_refresh_tokens",
["jti"],
unique=True,
)
op.create_index(
op.f("ix_oauth_provider_refresh_tokens_revoked"),
"oauth_provider_refresh_tokens",
["revoked"],
unique=False,
)
op.create_index(
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
"oauth_provider_refresh_tokens",
["token_hash"],
unique=True,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_user_revoked",
"oauth_provider_refresh_tokens",
["user_id", "revoked"],
unique=False,
)
op.create_index('ix_oauth_provider_refresh_tokens_client_user', 'oauth_provider_refresh_tokens', ['client_id', 'user_id'], unique=False)
op.create_index('ix_oauth_provider_refresh_tokens_expires_at', 'oauth_provider_refresh_tokens', ['expires_at'], unique=False)
op.create_index(op.f('ix_oauth_provider_refresh_tokens_jti'), 'oauth_provider_refresh_tokens', ['jti'], unique=True)
op.create_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), 'oauth_provider_refresh_tokens', ['revoked'], unique=False)
op.create_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), 'oauth_provider_refresh_tokens', ['token_hash'], unique=True)
op.create_index('ix_oauth_provider_refresh_tokens_user_revoked', 'oauth_provider_refresh_tokens', ['user_id', 'revoked'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_oauth_provider_refresh_tokens_user_revoked', table_name='oauth_provider_refresh_tokens')
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), table_name='oauth_provider_refresh_tokens')
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), table_name='oauth_provider_refresh_tokens')
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_jti'), table_name='oauth_provider_refresh_tokens')
op.drop_index('ix_oauth_provider_refresh_tokens_expires_at', table_name='oauth_provider_refresh_tokens')
op.drop_index('ix_oauth_provider_refresh_tokens_client_user', table_name='oauth_provider_refresh_tokens')
op.drop_table('oauth_provider_refresh_tokens')
op.drop_index('ix_oauth_consents_user_client', table_name='oauth_consents')
op.drop_table('oauth_consents')
op.drop_index('ix_oauth_authorization_codes_expires_at', table_name='oauth_authorization_codes')
op.drop_index(op.f('ix_oauth_authorization_codes_code'), table_name='oauth_authorization_codes')
op.drop_index('ix_oauth_authorization_codes_client_user', table_name='oauth_authorization_codes')
op.drop_table('oauth_authorization_codes')
op.drop_index(op.f('ix_user_sessions_user_id'), table_name='user_sessions')
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
op.drop_index(op.f('ix_user_sessions_refresh_token_jti'), table_name='user_sessions')
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
op.drop_index(op.f('ix_user_sessions_is_active'), table_name='user_sessions')
op.drop_table('user_sessions')
op.drop_index(op.f('ix_user_organizations_is_active'), table_name='user_organizations')
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
op.drop_index('ix_user_org_role', table_name='user_organizations')
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
op.drop_table('user_organizations')
op.drop_index(op.f('ix_oauth_clients_is_active'), table_name='oauth_clients')
op.drop_index(op.f('ix_oauth_clients_client_id'), table_name='oauth_clients')
op.drop_table('oauth_clients')
op.drop_index('ix_oauth_accounts_user_provider', table_name='oauth_accounts')
op.drop_index(op.f('ix_oauth_accounts_user_id'), table_name='oauth_accounts')
op.drop_index(op.f('ix_oauth_accounts_provider_email'), table_name='oauth_accounts')
op.drop_index(op.f('ix_oauth_accounts_provider'), table_name='oauth_accounts')
op.drop_table('oauth_accounts')
op.drop_index(op.f('ix_users_locale'), table_name='users')
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
op.drop_index(op.f('ix_users_is_active'), table_name='users')
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_index(op.f('ix_users_deleted_at'), table_name='users')
op.drop_table('users')
op.drop_index('ix_organizations_slug_active', table_name='organizations')
op.drop_index(op.f('ix_organizations_slug'), table_name='organizations')
op.drop_index('ix_organizations_name_active', table_name='organizations')
op.drop_index(op.f('ix_organizations_name'), table_name='organizations')
op.drop_index(op.f('ix_organizations_is_active'), table_name='organizations')
op.drop_table('organizations')
op.drop_index(op.f('ix_oauth_states_state'), table_name='oauth_states')
op.drop_table('oauth_states')
op.drop_index(
"ix_oauth_provider_refresh_tokens_user_revoked",
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
op.f("ix_oauth_provider_refresh_tokens_revoked"),
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
op.f("ix_oauth_provider_refresh_tokens_jti"),
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
"ix_oauth_provider_refresh_tokens_expires_at",
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
"ix_oauth_provider_refresh_tokens_client_user",
table_name="oauth_provider_refresh_tokens",
)
op.drop_table("oauth_provider_refresh_tokens")
op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
op.drop_table("oauth_consents")
op.drop_index(
"ix_oauth_authorization_codes_expires_at",
table_name="oauth_authorization_codes",
)
op.drop_index(
op.f("ix_oauth_authorization_codes_code"),
table_name="oauth_authorization_codes",
)
op.drop_index(
"ix_oauth_authorization_codes_client_user",
table_name="oauth_authorization_codes",
)
op.drop_table("oauth_authorization_codes")
op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
op.drop_index(
op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
)
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
op.drop_table("user_sessions")
op.drop_index(
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
)
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
op.drop_index("ix_user_org_role", table_name="user_organizations")
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
op.drop_table("user_organizations")
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
op.drop_table("oauth_clients")
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
op.drop_table("oauth_accounts")
op.drop_index(op.f("ix_users_locale"), table_name="users")
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
op.drop_index(op.f("ix_users_is_active"), table_name="users")
op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
op.drop_table("users")
op.drop_index("ix_organizations_slug_active", table_name="organizations")
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
op.drop_index("ix_organizations_name_active", table_name="organizations")
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
op.drop_table("organizations")
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
op.drop_table("oauth_states")
# ### end Alembic commands ###

View File

@@ -114,8 +114,13 @@ def upgrade() -> None:
def downgrade() -> None:
# Drop indexes in reverse order
op.drop_index("ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes")
op.drop_index("ix_perf_oauth_refresh_tokens_expires", table_name="oauth_provider_refresh_tokens")
op.drop_index(
"ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
)
op.drop_index(
"ix_perf_oauth_refresh_tokens_expires",
table_name="oauth_provider_refresh_tokens",
)
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
op.drop_index("ix_perf_users_active", table_name="users")

View File

@@ -0,0 +1,35 @@
"""rename oauth account token fields drop encrypted suffix
Revision ID: 0003
Revises: 0002
Create Date: 2026-02-27 01:03:18.869178
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0003"
down_revision: str | None = "0002"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.alter_column(
"oauth_accounts", "access_token_encrypted", new_column_name="access_token"
)
op.alter_column(
"oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token"
)
def downgrade() -> None:
op.alter_column(
"oauth_accounts", "access_token", new_column_name="access_token_encrypted"
)
op.alter_column(
"oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted"
)

View File

@@ -1,12 +1,12 @@
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
from app.core.database import get_db
from app.models.user import User
from app.repositories.user import user_repo
# OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
@@ -32,9 +32,8 @@ async def get_current_user(
# Decode token and get user ID
token_data = get_token_data(token)
# Get user from database
result = await db.execute(select(User).where(User.id == token_data.user_id))
user = result.scalar_one_or_none()
# Get user from database via repository
user = await user_repo.get(db, id=str(token_data.user_id))
if not user:
raise HTTPException(
@@ -144,8 +143,7 @@ async def get_optional_current_user(
try:
token_data = get_token_data(token)
result = await db.execute(select(User).where(User.id == token_data.user_id))
user = result.scalar_one_or_none()
user = await user_repo.get(db, id=str(token_data.user_id))
if not user or not user.is_active:
return None
return user

View File

@@ -15,9 +15,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.database import get_db
from app.crud.organization import organization as organization_crud
from app.models.user import User
from app.models.user_organization import OrganizationRole
from app.services.organization_service import organization_service
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
@@ -81,7 +81,7 @@ class OrganizationPermission:
return current_user
# Get user's role in organization
user_role = await organization_crud.get_user_role_in_org(
user_role = await organization_service.get_user_role_in_org(
db, user_id=current_user.id, organization_id=organization_id
)
@@ -123,7 +123,7 @@ async def require_org_membership(
if current_user.is_superuser:
return current_user
user_role = await organization_crud.get_user_role_in_org(
user_role = await organization_service.get_user_role_in_org(
db, user_id=current_user.id, organization_id=organization_id
)

View File

@@ -0,0 +1,41 @@
# app/api/dependencies/services.py
"""FastAPI dependency functions for service singletons."""
from app.services import oauth_provider_service
from app.services.auth_service import AuthService
from app.services.oauth_service import OAuthService
from app.services.organization_service import OrganizationService, organization_service
from app.services.session_service import SessionService, session_service
from app.services.user_service import UserService, user_service
def get_auth_service() -> AuthService:
"""Return the AuthService singleton for dependency injection."""
from app.services.auth_service import AuthService as _AuthService
return _AuthService()
def get_user_service() -> UserService:
"""Return the UserService singleton for dependency injection."""
return user_service
def get_organization_service() -> OrganizationService:
"""Return the OrganizationService singleton for dependency injection."""
return organization_service
def get_session_service() -> SessionService:
"""Return the SessionService singleton for dependency injection."""
return session_service
def get_oauth_service() -> OAuthService:
"""Return OAuthService for dependency injection."""
return OAuthService()
def get_oauth_provider_service():
"""Return the oauth_provider_service module for dependency injection."""
return oauth_provider_service

View File

@@ -14,7 +14,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query, status
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.permissions import require_superuser
@@ -25,12 +24,9 @@ from app.core.exceptions import (
ErrorCode,
NotFoundError,
)
from app.crud.organization import organization as organization_crud
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.organization import Organization
from app.core.repository_exceptions import DuplicateEntryError
from app.models.user import User
from app.models.user_organization import OrganizationRole, UserOrganization
from app.models.user_organization import OrganizationRole
from app.schemas.common import (
MessageResponse,
PaginatedResponse,
@@ -46,6 +42,9 @@ from app.schemas.organizations import (
)
from app.schemas.sessions import AdminSessionResponse
from app.schemas.users import UserCreate, UserResponse, UserUpdate
from app.services.organization_service import organization_service
from app.services.session_service import session_service
from app.services.user_service import user_service
logger = logging.getLogger(__name__)
@@ -66,7 +65,7 @@ class BulkUserAction(BaseModel):
action: BulkAction = Field(..., description="Action to perform on selected users")
user_ids: list[UUID] = Field(
..., min_items=1, max_items=100, description="List of user IDs (max 100)"
..., min_length=1, max_length=100, description="List of user IDs (max 100)"
)
@@ -178,38 +177,29 @@ async def admin_get_stats(
"""Get admin dashboard statistics with real data from database."""
from app.core.config import settings
# Check if we have any data
total_users_query = select(func.count()).select_from(User)
total_users = (await db.execute(total_users_query)).scalar() or 0
stats = await user_service.get_stats(db)
total_users = stats["total_users"]
active_count = stats["active_count"]
inactive_count = stats["inactive_count"]
all_users = stats["all_users"]
# If database is essentially empty (only admin user), return demo data
if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover
logger.info("Returning demo stats data (empty database in demo mode)")
return _generate_demo_stats()
# 1. User Growth (Last 30 days) - Improved calculation
datetime.now(UTC) - timedelta(days=30)
# Get all users with their creation dates
all_users_query = select(User).order_by(User.created_at)
result = await db.execute(all_users_query)
all_users = result.scalars().all()
# Build cumulative counts per day
# 1. User Growth (Last 30 days)
user_growth = []
for i in range(29, -1, -1):
date = datetime.now(UTC) - timedelta(days=i)
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
date_end = date_start + timedelta(days=1)
# Count all users created before end of this day
# Make comparison timezone-aware
total_users_on_date = sum(
1
for u in all_users
if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end
)
# Count active users created before end of this day
active_users_on_date = sum(
1
for u in all_users
@@ -227,27 +217,16 @@ async def admin_get_stats(
)
# 2. Organization Distribution - Top 6 organizations by member count
org_query = (
select(Organization.name, func.count(UserOrganization.user_id).label("count"))
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.name)
.order_by(func.count(UserOrganization.user_id).desc())
.limit(6)
)
result = await db.execute(org_query)
org_dist = [
OrgDistributionData(name=row.name, value=row.count) for row in result.all()
]
org_rows = await organization_service.get_org_distribution(db, limit=6)
org_dist = [OrgDistributionData(name=r["name"], value=r["value"]) for r in org_rows]
# 3. User Registration Activity (Last 14 days) - NEW
# 3. User Registration Activity (Last 14 days)
registration_activity = []
for i in range(13, -1, -1):
date = datetime.now(UTC) - timedelta(days=i)
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
date_end = date_start + timedelta(days=1)
# Count users created on this specific day
# Make comparison timezone-aware
day_registrations = sum(
1
for u in all_users
@@ -263,14 +242,6 @@ async def admin_get_stats(
)
# 4. User Status - Active vs Inactive
active_query = select(func.count()).select_from(User).where(User.is_active)
inactive_query = (
select(func.count()).select_from(User).where(User.is_active.is_(False))
)
active_count = (await db.execute(active_query)).scalar() or 0
inactive_count = (await db.execute(inactive_query)).scalar() or 0
logger.info(
f"User status counts - Active: {active_count}, Inactive: {inactive_count}"
)
@@ -321,7 +292,7 @@ async def admin_list_users(
filters["is_superuser"] = is_superuser
# Get users with search
users, total = await user_crud.get_multi_with_total(
users, total = await user_service.list_users(
db,
skip=pagination.offset,
limit=pagination.limit,
@@ -364,12 +335,12 @@ async def admin_create_user(
Allows setting is_superuser and other fields.
"""
try:
user = await user_crud.create(db, obj_in=user_in)
user = await user_service.create_user(db, user_in)
logger.info(f"Admin {admin.email} created user {user.email}")
return user
except ValueError as e:
except DuplicateEntryError as e:
logger.warning(f"Failed to create user: {e!s}")
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
raise DuplicateError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
except Exception as e:
logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
raise
@@ -388,11 +359,7 @@ async def admin_get_user(
db: AsyncSession = Depends(get_db),
) -> Any:
"""Get detailed information about a specific user."""
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
user = await user_service.get_user(db, str(user_id))
return user
@@ -411,18 +378,11 @@ async def admin_update_user(
) -> Any:
"""Update user information with admin privileges."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
user = await user_service.get_user(db, str(user_id))
updated_user = await user_service.update_user(db, user=user, obj_in=user_in)
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
return updated_user
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
raise
@@ -442,11 +402,7 @@ async def admin_delete_user(
) -> Any:
"""Soft delete a user (sets deleted_at timestamp)."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
user = await user_service.get_user(db, str(user_id))
# Prevent deleting yourself
if user.id == admin.id:
@@ -456,15 +412,13 @@ async def admin_delete_user(
error_code=ErrorCode.OPERATION_FORBIDDEN,
)
await user_crud.soft_delete(db, id=user_id)
await user_service.soft_delete_user(db, str(user_id))
logger.info(f"Admin {admin.email} deleted user {user.email}")
return MessageResponse(
success=True, message=f"User {user.email} has been deleted"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
raise
@@ -484,21 +438,14 @@ async def admin_activate_user(
) -> Any:
"""Activate a user account."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
user = await user_service.get_user(db, str(user_id))
await user_service.update_user(db, user=user, obj_in={"is_active": True})
logger.info(f"Admin {admin.email} activated user {user.email}")
return MessageResponse(
success=True, message=f"User {user.email} has been activated"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
raise
@@ -518,11 +465,7 @@ async def admin_deactivate_user(
) -> Any:
"""Deactivate a user account."""
try:
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
user = await user_service.get_user(db, str(user_id))
# Prevent deactivating yourself
if user.id == admin.id:
@@ -532,15 +475,13 @@ async def admin_deactivate_user(
error_code=ErrorCode.OPERATION_FORBIDDEN,
)
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
await user_service.update_user(db, user=user, obj_in={"is_active": False})
logger.info(f"Admin {admin.email} deactivated user {user.email}")
return MessageResponse(
success=True, message=f"User {user.email} has been deactivated"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
raise
@@ -567,16 +508,16 @@ async def admin_bulk_user_action(
try:
# Use efficient bulk operations instead of loop
if bulk_action.action == BulkAction.ACTIVATE:
affected_count = await user_crud.bulk_update_status(
affected_count = await user_service.bulk_update_status(
db, user_ids=bulk_action.user_ids, is_active=True
)
elif bulk_action.action == BulkAction.DEACTIVATE:
affected_count = await user_crud.bulk_update_status(
affected_count = await user_service.bulk_update_status(
db, user_ids=bulk_action.user_ids, is_active=False
)
elif bulk_action.action == BulkAction.DELETE:
# bulk_soft_delete automatically excludes the admin user
affected_count = await user_crud.bulk_soft_delete(
affected_count = await user_service.bulk_soft_delete(
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
)
else: # pragma: no cover
@@ -624,7 +565,7 @@ async def admin_list_organizations(
"""List all organizations with filtering and search."""
try:
# Use optimized method that gets member counts in single query (no N+1)
orgs_with_data, total = await organization_crud.get_multi_with_member_counts(
orgs_with_data, total = await organization_service.get_multi_with_member_counts(
db,
skip=pagination.offset,
limit=pagination.limit,
@@ -680,7 +621,7 @@ async def admin_create_organization(
) -> Any:
"""Create a new organization."""
try:
org = await organization_crud.create(db, obj_in=org_in)
org = await organization_service.create_organization(db, obj_in=org_in)
logger.info(f"Admin {admin.email} created organization {org.name}")
# Add member count
@@ -697,9 +638,9 @@ async def admin_create_organization(
}
return OrganizationResponse(**org_dict)
except ValueError as e:
except DuplicateEntryError as e:
logger.warning(f"Failed to create organization: {e!s}")
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
raise DuplicateError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
except Exception as e:
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
raise
@@ -718,12 +659,7 @@ async def admin_get_organization(
db: AsyncSession = Depends(get_db),
) -> Any:
"""Get detailed information about a specific organization."""
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
)
org = await organization_service.get_organization(db, str(org_id))
org_dict = {
"id": org.id,
"name": org.name,
@@ -733,7 +669,7 @@ async def admin_get_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count(
"member_count": await organization_service.get_member_count(
db, organization_id=org.id
),
}
@@ -755,14 +691,10 @@ async def admin_update_organization(
) -> Any:
"""Update organization information."""
try:
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
org = await organization_service.get_organization(db, str(org_id))
updated_org = await organization_service.update_organization(
db, org=org, obj_in=org_in
)
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
org_dict = {
@@ -774,14 +706,12 @@ async def admin_update_organization(
"settings": updated_org.settings,
"created_at": updated_org.created_at,
"updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count(
"member_count": await organization_service.get_member_count(
db, organization_id=updated_org.id
),
}
return OrganizationResponse(**org_dict)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
raise
@@ -801,22 +731,14 @@ async def admin_delete_organization(
) -> Any:
"""Delete an organization and all its relationships."""
try:
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
await organization_crud.remove(db, id=org_id)
org = await organization_service.get_organization(db, str(org_id))
await organization_service.remove_organization(db, str(org_id))
logger.info(f"Admin {admin.email} deleted organization {org.name}")
return MessageResponse(
success=True, message=f"Organization {org.name} has been deleted"
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
raise
@@ -838,14 +760,8 @@ async def admin_list_organization_members(
) -> Any:
"""List all members of an organization."""
try:
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
members, total = await organization_crud.get_organization_members(
await organization_service.get_organization(db, str(org_id)) # validates exists
members, total = await organization_service.get_organization_members(
db,
organization_id=org_id,
skip=pagination.offset,
@@ -898,21 +814,10 @@ async def admin_add_organization_member(
) -> Any:
"""Add a user to an organization."""
try:
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
org = await organization_service.get_organization(db, str(org_id))
user = await user_service.get_user(db, str(request.user_id))
user = await user_crud.get(db, id=request.user_id)
if not user:
raise NotFoundError(
message=f"User {request.user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
await organization_crud.add_user(
await organization_service.add_member(
db, organization_id=org_id, user_id=request.user_id, role=request.role
)
@@ -925,14 +830,11 @@ async def admin_add_organization_member(
success=True, message=f"User {user.email} added to organization {org.name}"
)
except ValueError as e:
except DuplicateEntryError as e:
logger.warning(f"Failed to add user to organization: {e!s}")
# Use DuplicateError for "already exists" scenarios
raise DuplicateError(
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
)
except NotFoundError:
raise
except Exception as e:
logger.error(
f"Error adding member to organization (admin): {e!s}", exc_info=True
@@ -955,20 +857,10 @@ async def admin_remove_organization_member(
) -> Any:
"""Remove a user from an organization."""
try:
org = await organization_crud.get(db, id=org_id)
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
org = await organization_service.get_organization(db, str(org_id))
user = await user_service.get_user(db, str(user_id))
user = await user_crud.get(db, id=user_id)
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
success = await organization_crud.remove_user(
success = await organization_service.remove_member(
db, organization_id=org_id, user_id=user_id
)
@@ -1022,7 +914,7 @@ async def admin_list_sessions(
"""List all sessions across all users with filtering and pagination."""
try:
# Get sessions with user info (eager loaded to prevent N+1)
sessions, total = await session_crud.get_all_sessions(
sessions, total = await session_service.get_all_sessions(
db,
skip=pagination.offset,
limit=pagination.limit,

View File

@@ -15,16 +15,14 @@ from app.core.auth import (
TokenExpiredError,
TokenInvalidError,
decode_token,
get_password_hash,
)
from app.core.database import get_db
from app.core.exceptions import (
AuthenticationError as AuthError,
DatabaseError,
DuplicateError,
ErrorCode,
)
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import LogoutRequest, SessionCreate
@@ -39,6 +37,8 @@ from app.schemas.users import (
)
from app.services.auth_service import AuthenticationError, AuthService
from app.services.email_service import email_service
from app.services.session_service import session_service
from app.services.user_service import user_service
from app.utils.device import extract_device_info
from app.utils.security import create_password_reset_token, verify_password_reset_token
@@ -91,7 +91,7 @@ async def _create_login_session(
location_country=device_info.location_country,
)
await session_crud.create_session(db, obj_in=session_data)
await session_service.create_session(db, obj_in=session_data)
logger.info(
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
@@ -123,8 +123,14 @@ async def register_user(
try:
user = await AuthService.create_user(db, user_data)
return user
except AuthenticationError as e:
except DuplicateError:
# SECURITY: Don't reveal if email exists - generic error message
logger.warning(f"Registration failed: duplicate email {user_data.email}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Registration failed. Please check your information and try again.",
)
except AuthError as e:
logger.warning(f"Registration failed: {e!s}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -177,9 +183,6 @@ async def login(
# Handle specific authentication errors like inactive accounts
logger.warning(f"Authentication failed: {e!s}")
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
@@ -226,9 +229,6 @@ async def login_oauth(
except AuthenticationError as e:
logger.warning(f"OAuth authentication failed: {e!s}")
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e:
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
raise DatabaseError(
@@ -259,7 +259,7 @@ async def refresh_token(
)
# Check if session exists and is active
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
if not session:
logger.warning(
@@ -279,7 +279,7 @@ async def refresh_token(
# Update session with new refresh token JTI and expiration
try:
await session_crud.update_refresh_token(
await session_service.update_refresh_token(
db,
session=session,
new_jti=new_refresh_payload.jti,
@@ -347,7 +347,7 @@ async def request_password_reset(
"""
try:
# Look up user by email
user = await user_crud.get_by_email(db, email=reset_request.email)
user = await user_service.get_by_email(db, email=reset_request.email)
# Only send email if user exists and is active
if user and user.is_active:
@@ -412,31 +412,23 @@ async def confirm_password_reset(
detail="Invalid or expired password reset token",
)
# Look up user
user = await user_crud.get_by_email(db, email=email)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
# Reset password via service (validates user exists and is active)
try:
user = await AuthService.reset_password(
db, email=email, new_password=reset_confirm.new_password
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User account is inactive",
)
# Update password
user.password_hash = get_password_hash(reset_confirm.new_password)
db.add(user)
await db.commit()
except AuthenticationError as e:
err_msg = str(e)
if "inactive" in err_msg.lower():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err_msg)
# SECURITY: Invalidate all existing sessions after password reset
# This prevents stolen sessions from being used after password change
from app.crud.session import session as session_crud
try:
deactivated_count = await session_crud.deactivate_all_user_sessions(
deactivated_count = await session_service.deactivate_all_user_sessions(
db, user_id=str(user.id)
)
logger.info(
@@ -511,7 +503,7 @@ async def logout(
return MessageResponse(success=True, message="Logged out successfully")
# Find the session by JTI
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
if session:
# Verify session belongs to current user (security check)
@@ -526,7 +518,7 @@ async def logout(
)
# Deactivate the session
await session_crud.deactivate(db, session_id=str(session.id))
await session_service.deactivate(db, session_id=str(session.id))
logger.info(
f"User {current_user.id} logged out from {session.device_name} "
@@ -584,7 +576,7 @@ async def logout_all(
"""
try:
# Deactivate all sessions for this user
count = await session_crud.deactivate_all_user_sessions(
count = await session_service.deactivate_all_user_sessions(
db, user_id=str(current_user.id)
)

View File

@@ -25,8 +25,6 @@ from app.core.auth import decode_token
from app.core.config import settings
from app.core.database import get_db
from app.core.exceptions import AuthenticationError as AuthError
from app.crud import oauth_account
from app.crud.session import session as session_crud
from app.models.user import User
from app.schemas.oauth import (
OAuthAccountsListResponse,
@@ -38,6 +36,7 @@ from app.schemas.oauth import (
from app.schemas.sessions import SessionCreate
from app.schemas.users import Token
from app.services.oauth_service import OAuthService
from app.services.session_service import session_service
from app.utils.device import extract_device_info
router = APIRouter()
@@ -82,7 +81,7 @@ async def _create_oauth_login_session(
location_country=device_info.location_country,
)
await session_crud.create_session(db, obj_in=session_data)
await session_service.create_session(db, obj_in=session_data)
logger.info(
f"OAuth login successful: {user.email} via {provider} "
@@ -289,7 +288,7 @@ async def list_accounts(
Returns:
List of linked OAuth accounts
"""
accounts = await oauth_account.get_user_accounts(db, user_id=current_user.id)
accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
return OAuthAccountsListResponse(accounts=accounts)
@@ -397,7 +396,7 @@ async def start_link(
)
# Check if user already has this provider linked
existing = await oauth_account.get_user_account_by_provider(
existing = await OAuthService.get_user_account_by_provider(
db, user_id=current_user.id, provider=provider
)
if existing:

View File

@@ -34,7 +34,6 @@ from app.api.dependencies.auth import (
)
from app.core.config import settings
from app.core.database import get_db
from app.crud import oauth_client as oauth_client_crud
from app.models.user import User
from app.schemas.oauth import (
OAuthClientCreate,
@@ -656,7 +655,7 @@ async def introspect(
)
except Exception as e:
logger.warning(f"Token introspection error: {e}")
return OAuthTokenIntrospectionResponse(active=False)
return OAuthTokenIntrospectionResponse(active=False) # pyright: ignore[reportCallIssue]
# ============================================================================
@@ -712,7 +711,7 @@ async def register_client(
client_type=client_type,
)
client, secret = await oauth_client_crud.create_client(db, obj_in=client_data)
client, secret = await provider_service.register_client(db, client_data)
# Update MCP server URL if provided
if mcp_server_url:
@@ -750,7 +749,7 @@ async def list_clients(
current_user: User = Depends(get_current_superuser),
) -> list[OAuthClientResponse]:
"""List all OAuth clients."""
clients = await oauth_client_crud.get_all_clients(db)
clients = await provider_service.list_clients(db)
return [OAuthClientResponse.model_validate(c) for c in clients]
@@ -776,7 +775,7 @@ async def delete_client(
detail="Client not found",
)
await oauth_client_crud.delete_client(db, client_id=client_id)
await provider_service.delete_client_by_id(db, client_id=client_id)
# ============================================================================
@@ -797,30 +796,7 @@ async def list_my_consents(
current_user: User = Depends(get_current_active_user),
) -> list[dict]:
"""List applications the user has authorized."""
from sqlalchemy import select
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent
result = await db.execute(
select(OAuthConsent, OAuthClient)
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
.where(OAuthConsent.user_id == current_user.id)
)
rows = result.all()
return [
{
"client_id": consent.client_id,
"client_name": client.client_name,
"client_description": client.client_description,
"granted_scopes": consent.granted_scopes.split()
if consent.granted_scopes
else [],
"granted_at": consent.created_at.isoformat(),
}
for consent, client in rows
]
return await provider_service.list_user_consents(db, user_id=current_user.id)
@router.delete(

View File

@@ -15,8 +15,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership
from app.core.database import get_db
from app.core.exceptions import ErrorCode, NotFoundError
from app.crud.organization import organization as organization_crud
from app.models.user import User
from app.schemas.common import (
PaginatedResponse,
@@ -28,6 +26,7 @@ from app.schemas.organizations import (
OrganizationResponse,
OrganizationUpdate,
)
from app.services.organization_service import organization_service
logger = logging.getLogger(__name__)
@@ -54,7 +53,7 @@ async def get_my_organizations(
"""
try:
# Get all org data in single query with JOIN and subquery
orgs_data = await organization_crud.get_user_organizations_with_details(
orgs_data = await organization_service.get_user_organizations_with_details(
db, user_id=current_user.id, is_active=is_active
)
@@ -100,13 +99,7 @@ async def get_organization(
User must be a member of the organization.
"""
try:
org = await organization_crud.get(db, id=organization_id)
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError(
detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
org = await organization_service.get_organization(db, str(organization_id))
org_dict = {
"id": org.id,
"name": org.name,
@@ -116,14 +109,12 @@ async def get_organization(
"settings": org.settings,
"created_at": org.created_at,
"updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count(
"member_count": await organization_service.get_member_count(
db, organization_id=org.id
),
}
return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e:
logger.error(f"Error getting organization: {e!s}", exc_info=True)
raise
@@ -149,7 +140,7 @@ async def get_organization_members(
User must be a member of the organization to view members.
"""
try:
members, total = await organization_crud.get_organization_members(
members, total = await organization_service.get_organization_members(
db,
organization_id=organization_id,
skip=pagination.offset,
@@ -192,14 +183,10 @@ async def update_organization(
Requires owner or admin role in the organization.
"""
try:
org = await organization_crud.get(db, id=organization_id)
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError(
detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
org = await organization_service.get_organization(db, str(organization_id))
updated_org = await organization_service.update_organization(
db, org=org, obj_in=org_in
)
logger.info(
f"User {current_user.email} updated organization {updated_org.name}"
)
@@ -213,14 +200,12 @@ async def update_organization(
"settings": updated_org.settings,
"created_at": updated_org.created_at,
"updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count(
"member_count": await organization_service.get_member_count(
db, organization_id=updated_org.id
),
}
return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e:
logger.error(f"Error updating organization: {e!s}", exc_info=True)
raise

View File

@@ -17,10 +17,10 @@ from app.api.dependencies.auth import get_current_user
from app.core.auth import decode_token
from app.core.database import get_db
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
from app.crud.session import session as session_crud
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionListResponse, SessionResponse
from app.services.session_service import session_service
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ async def list_my_sessions(
"""
try:
# Get all active sessions for user
sessions = await session_crud.get_user_sessions(
sessions = await session_service.get_user_sessions(
db, user_id=str(current_user.id), active_only=True
)
@@ -150,7 +150,7 @@ async def revoke_session(
"""
try:
# Get the session
session = await session_crud.get(db, id=str(session_id))
session = await session_service.get_session(db, str(session_id))
if not session:
raise NotFoundError(
@@ -170,7 +170,7 @@ async def revoke_session(
)
# Deactivate the session
await session_crud.deactivate(db, session_id=str(session_id))
await session_service.deactivate(db, session_id=str(session_id))
logger.info(
f"User {current_user.id} revoked session {session_id} "
@@ -224,7 +224,7 @@ async def cleanup_expired_sessions(
"""
try:
# Use optimized bulk DELETE instead of N individual deletes
deleted_count = await session_crud.cleanup_expired_for_user(
deleted_count = await session_service.cleanup_expired_for_user(
db, user_id=str(current_user.id)
)

View File

@@ -13,8 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_superuser, get_current_user
from app.core.database import get_db
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
from app.crud.user import user as user_crud
from app.core.exceptions import AuthorizationError, ErrorCode
from app.models.user import User
from app.schemas.common import (
MessageResponse,
@@ -25,6 +24,7 @@ from app.schemas.common import (
)
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
from app.services.auth_service import AuthenticationError, AuthService
from app.services.user_service import user_service
logger = logging.getLogger(__name__)
@@ -71,7 +71,7 @@ async def list_users(
filters["is_superuser"] = is_superuser
# Get paginated users with total count
users, total = await user_crud.get_multi_with_total(
users, total = await user_service.list_users(
db,
skip=pagination.offset,
limit=pagination.limit,
@@ -107,7 +107,9 @@ async def list_users(
""",
operation_id="get_current_user_profile",
)
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
async def get_current_user_profile(
current_user: User = Depends(get_current_user),
) -> Any:
"""Get current user's profile."""
return current_user
@@ -138,8 +140,8 @@ async def update_current_user(
Users cannot elevate their own permissions (protected by UserUpdate schema validator).
"""
try:
updated_user = await user_crud.update(
db, db_obj=current_user, obj_in=user_update
updated_user = await user_service.update_user(
db, user=current_user, obj_in=user_update
)
logger.info(f"User {current_user.id} updated their profile")
return updated_user
@@ -190,13 +192,7 @@ async def get_user_by_id(
)
# Get user
user = await user_crud.get(db, id=str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
user = await user_service.get_user(db, str(user_id))
return user
@@ -241,15 +237,10 @@ async def update_user(
)
# Get user
user = await user_crud.get(db, id=str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
user = await user_service.get_user(db, str(user_id))
try:
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
logger.info(f"User {user_id} updated by {current_user.id}")
return updated_user
except ValueError as e:
@@ -346,17 +337,12 @@ async def delete_user(
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Get user
user = await user_crud.get(db, id=str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
# Get user (raises NotFoundError if not found)
await user_service.get_user(db, str(user_id))
try:
# Use soft delete instead of hard delete
await user_crud.soft_delete(db, id=str(user_id))
await user_service.soft_delete_user(db, str(user_id))
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
return MessageResponse(
success=True, message=f"User {user_id} deleted successfully"

View File

@@ -222,7 +222,7 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
)
error_response = ErrorResponse(
errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
)
return JSONResponse(
@@ -254,7 +254,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
message = f"{type(exc).__name__}: {exc!s}"
error_response = ErrorResponse(
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
)
return JSONResponse(

View File

@@ -0,0 +1,26 @@
"""
Custom exceptions for the repository layer.
These exceptions allow services and routes to handle database-level errors
with proper semantics, without leaking SQLAlchemy internals.
"""
class RepositoryError(Exception):
"""Base for all repository-layer errors."""
class DuplicateEntryError(RepositoryError):
"""Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
class IntegrityConstraintError(RepositoryError):
"""Raised on FK or check constraint violations."""
class RecordNotFoundError(RepositoryError):
"""Raised when an expected record doesn't exist."""
class InvalidInputError(RepositoryError):
"""Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""

View File

@@ -1,14 +0,0 @@
# app/crud/__init__.py
from .oauth import oauth_account, oauth_client, oauth_state
from .organization import organization
from .session import session as session_crud
from .user import user
__all__ = [
"oauth_account",
"oauth_client",
"oauth_state",
"organization",
"session_crud",
"user",
]

View File

@@ -1,718 +0,0 @@
"""
Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns.
Provides operations for:
- OAuthAccount: Managing linked OAuth provider accounts
- OAuthState: CSRF protection state during OAuth flows
- OAuthClient: Registered OAuth clients (provider mode skeleton)
"""
import logging
import secrets
from datetime import UTC, datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.oauth_account import OAuthAccount
from app.models.oauth_client import OAuthClient
from app.models.oauth_state import OAuthState
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
logger = logging.getLogger(__name__)
# ============================================================================
# OAuth Account CRUD
# ============================================================================
class EmptySchema(BaseModel):
"""Placeholder schema for CRUD operations that don't need update schemas."""
class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]):
"""CRUD operations for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and provider user ID.
Args:
db: Database session
provider: OAuth provider name (google, github)
provider_user_id: User ID from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and email.
Used for auto-linking existing accounts by email.
Args:
db: Database session
provider: OAuth provider name
email: Email address from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider} email {email}: {e!s}"
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""
Get all OAuth accounts linked to a user.
Args:
db: Database session
user_id: User ID
Returns:
List of OAuthAccount objects
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""
Get a specific OAuth account for a user and provider.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
OAuthAccount if found, None otherwise
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""
Create a new OAuth account link.
Args:
db: Database session
obj_in: OAuth account creation data
Returns:
Created OAuthAccount
Raises:
ValueError: If account already exists or creation fails
"""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token_encrypted=obj_in.access_token_encrypted,
refresh_token_encrypted=obj_in.refresh_token_encrypted,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
)
raise ValueError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error(f"Integrity error creating OAuth account: {error_msg}")
raise ValueError(f"Failed to create OAuth account: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""
Delete an OAuth account link.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
True if deleted, False if not found
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
f"OAuth account deleted: {provider} unlinked from user {user_id}"
)
else:
logger.warning(
f"OAuth account not found for deletion: {provider} for user {user_id}"
)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token_encrypted: str | None = None,
refresh_token_encrypted: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""
Update OAuth tokens for an account.
Args:
db: Database session
account: OAuthAccount to update
access_token_encrypted: New encrypted access token
refresh_token_encrypted: New encrypted refresh token
token_expires_at: New token expiration time
Returns:
Updated OAuthAccount
"""
try:
if access_token_encrypted is not None:
account.access_token_encrypted = access_token_encrypted
if refresh_token_encrypted is not None:
account.refresh_token_encrypted = refresh_token_encrypted
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error updating OAuth tokens: {e!s}")
raise
# ============================================================================
# OAuth State CRUD
# ============================================================================
class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]):
"""CRUD operations for OAuth state (CSRF protection)."""
async def create_state(
self, db: AsyncSession, *, obj_in: OAuthStateCreate
) -> OAuthState:
"""
Create a new OAuth state for CSRF protection.
Args:
db: Database session
obj_in: OAuth state creation data
Returns:
Created OAuthState
"""
try:
db_obj = OAuthState(
state=obj_in.state,
code_verifier=obj_in.code_verifier,
nonce=obj_in.nonce,
provider=obj_in.provider,
redirect_uri=obj_in.redirect_uri,
user_id=obj_in.user_id,
expires_at=obj_in.expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.debug(f"OAuth state created for {obj_in.provider}")
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
# State collision (extremely rare with cryptographic random)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"OAuth state collision: {error_msg}")
raise ValueError("Failed to create OAuth state, please retry")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
raise
async def get_and_consume_state(
self, db: AsyncSession, *, state: str
) -> OAuthState | None:
"""
Get and delete OAuth state (consume it).
This ensures each state can only be used once (replay protection).
Args:
db: Database session
state: State string to look up
Returns:
OAuthState if found and valid, None otherwise
"""
try:
# Get the state
result = await db.execute(
select(OAuthState).where(OAuthState.state == state)
)
db_obj = result.scalar_one_or_none()
if db_obj is None:
logger.warning(f"OAuth state not found: {state[:8]}...")
return None
# Check expiration
# Handle both timezone-aware and timezone-naive datetimes
now = datetime.now(UTC)
expires_at = db_obj.expires_at
if expires_at.tzinfo is None:
# SQLite returns naive datetimes, assume UTC
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.warning(f"OAuth state expired: {state[:8]}...")
await db.delete(db_obj)
await db.commit()
return None
# Delete it (consume)
await db.delete(db_obj)
await db.commit()
logger.debug(f"OAuth state consumed: {state[:8]}...")
return db_obj
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error consuming OAuth state: {e!s}")
raise
async def cleanup_expired(self, db: AsyncSession) -> int:
"""
Clean up expired OAuth states.
Should be called periodically to remove stale states.
Args:
db: Database session
Returns:
Number of states deleted
"""
try:
now = datetime.now(UTC)
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired OAuth states")
return count
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
raise
# ============================================================================
# OAuth Client CRUD (Provider Mode - Skeleton)
# ============================================================================
class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
"""
CRUD operations for OAuth clients (provider mode).
This is a skeleton implementation for MCP client registration.
Full implementation can be expanded when needed.
"""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Get OAuth client by client_id.
Args:
db: Database session
client_id: OAuth client ID
Returns:
OAuthClient if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""
Create a new OAuth client.
Args:
db: Database session
obj_in: OAuth client creation data
owner_user_id: Optional owner user ID
Returns:
Tuple of (created OAuthClient, client_secret or None for public clients)
"""
try:
# Generate client_id
client_id = secrets.token_urlsafe(32)
# Generate client_secret for confidential clients
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
# SECURITY: Use bcrypt for secret storage (not SHA-256)
# bcrypt is computationally expensive, making brute-force attacks infeasible
from app.core.auth import get_password_hash
client_secret_hash = get_password_hash(client_secret)
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
)
return db_obj, client_secret
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Error creating OAuth client: {error_msg}")
raise ValueError(f"Failed to create OAuth client: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Deactivate an OAuth client.
Args:
db: Database session
client_id: OAuth client ID
Returns:
Deactivated OAuthClient if found, None otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info(f"OAuth client deactivated: {client.client_name}")
return client
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""
Validate that a redirect URI is allowed for a client.
Args:
db: Database session
client_id: OAuth client ID
redirect_uri: Redirect URI to validate
Returns:
True if valid, False otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e: # pragma: no cover
logger.error(f"Error validating redirect URI: {e!s}")
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""
Verify client credentials.
Args:
db: Database session
client_id: OAuth client ID
client_secret: Client secret to verify
Returns:
True if valid, False otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
# SECURITY: Verify secret using bcrypt (not SHA-256)
# This supports both old SHA-256 hashes (for migration) and new bcrypt hashes
from app.core.auth import verify_password
stored_hash: str = str(client.client_secret_hash)
# Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256
if stored_hash.startswith("$2"):
# New bcrypt format
return verify_password(client_secret, stored_hash)
else:
# Legacy SHA-256 format - still support for migration
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e: # pragma: no cover
logger.error(f"Error verifying client secret: {e!s}")
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""
Get all OAuth clients.
Args:
db: Database session
include_inactive: Whether to include inactive clients
Returns:
List of OAuthClient objects
"""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting all OAuth clients: {e!s}")
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""
Delete an OAuth client permanently.
Note: This will cascade delete related records (tokens, consents, etc.)
due to foreign key constraints.
Args:
db: Database session
client_id: OAuth client ID
Returns:
True if deleted, False if not found
"""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(f"OAuth client deleted: {client_id}")
else:
logger.warning(f"OAuth client not found for deletion: {client_id}")
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
raise
# ============================================================================
# Singleton instances
# ============================================================================
oauth_account = CRUDOAuthAccount(OAuthAccount)
oauth_state = CRUDOAuthState(OAuthState)
oauth_client = CRUDOAuthClient(OAuthClient)

View File

@@ -16,10 +16,10 @@ from sqlalchemy import select, text
from app.core.config import settings
from app.core.database import SessionLocal, engine
from app.crud.user import user as user_crud
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import UserOrganization
from app.repositories.user import user_repo as user_crud
from app.schemas.users import UserCreate
logger = logging.getLogger(__name__)

View File

@@ -36,9 +36,9 @@ class OAuthAccount(Base, UUIDMixin, TimestampMixin):
) # Email from provider (for reference)
# Optional: store provider tokens for API access
# These should be encrypted at rest in production
access_token_encrypted = Column(String(2048), nullable=True)
refresh_token_encrypted = Column(String(2048), nullable=True)
# TODO: Encrypt these at rest in production (requires key management infrastructure)
access_token = Column(String(2048), nullable=True)
refresh_token = Column(String(2048), nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Relationship

View File

@@ -92,7 +92,7 @@ class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
# Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at
return bool(now > expires_at)
@property
def is_valid(self) -> bool:

View File

@@ -99,7 +99,7 @@ class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
# Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at
return bool(now > expires_at)
@property
def is_valid(self) -> bool:

View File

@@ -76,7 +76,11 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
"""Check if session has expired."""
from datetime import datetime
return self.expires_at < datetime.now(UTC)
now = datetime.now(UTC)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return bool(expires_at < now)
def to_dict(self):
"""Convert session to dictionary for serialization."""

View File

@@ -0,0 +1,39 @@
# app/repositories/__init__.py
"""Repository layer — all database access goes through these classes."""
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
from app.repositories.oauth_authorization_code import (
OAuthAuthorizationCodeRepository,
oauth_authorization_code_repo,
)
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
from app.repositories.oauth_provider_token import (
OAuthProviderTokenRepository,
oauth_provider_token_repo,
)
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
from app.repositories.organization import OrganizationRepository, organization_repo
from app.repositories.session import SessionRepository, session_repo
from app.repositories.user import UserRepository, user_repo
__all__ = [
"OAuthAccountRepository",
"OAuthAuthorizationCodeRepository",
"OAuthClientRepository",
"OAuthConsentRepository",
"OAuthProviderTokenRepository",
"OAuthStateRepository",
"OrganizationRepository",
"SessionRepository",
"UserRepository",
"oauth_account_repo",
"oauth_authorization_code_repo",
"oauth_client_repo",
"oauth_consent_repo",
"oauth_provider_token_repo",
"oauth_state_repo",
"organization_repo",
"session_repo",
"user_repo",
]

View File

@@ -1,6 +1,6 @@
# app/crud/base_async.py
# app/repositories/base.py
"""
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
Base repository class for async CRUD operations using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models.
"""
@@ -18,6 +18,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load
from app.core.database import Base
from app.core.repository_exceptions import (
DuplicateEntryError,
IntegrityConstraintError,
InvalidInputError,
)
logger = logging.getLogger(__name__)
@@ -26,16 +31,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase[
class BaseRepository[
ModelType: Base,
CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel,
]:
"""Async CRUD operations for a model."""
"""Async repository operations for a model."""
def __init__(self, model: type[ModelType]):
"""
CRUD object with default async methods to Create, Read, Update, Delete.
Repository object with default async methods to Create, Read, Update, Delete.
Parameters:
model: A SQLAlchemy model class
@@ -56,13 +61,7 @@ class CRUDBase[
Returns:
Model instance or None if not found
Example:
# Eager load user relationship
from sqlalchemy.orm import joinedload
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
"""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
@@ -75,7 +74,6 @@ class CRUDBase[
try:
query = select(self.model).where(self.model.id == uuid_obj)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
@@ -96,28 +94,17 @@ class CRUDBase[
) -> list[ModelType]:
"""
Get multiple records with pagination validation and optional eager loading.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
options: Optional list of SQLAlchemy load options for eager loading
Returns:
List of model instances
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
raise InvalidInputError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
raise InvalidInputError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
raise InvalidInputError("Maximum limit is 1000")
try:
query = select(self.model).offset(skip).limit(limit)
query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
# Apply eager loading options if provided
if options:
for option in options:
query = query.options(option)
@@ -136,9 +123,8 @@ class CRUDBase[
"""Create a new record with error handling.
NOTE: This method is defensive code that's never called in practice.
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
with their own implementations, so the base implementation and its exception handlers
are never executed. Marked as pragma: no cover to avoid false coverage gaps.
All repository subclasses override this method with their own implementations.
Marked as pragma: no cover to avoid false coverage gaps.
"""
try: # pragma: no cover
obj_in_data = jsonable_encoder(obj_in)
@@ -154,15 +140,15 @@ class CRUDBase[
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: # pragma: no cover
await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {e!s}")
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
@@ -200,15 +186,15 @@ class CRUDBase[
logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
)
raise ValueError(
raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists"
)
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
raise ValueError(f"Database operation failed: {e!s}")
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e:
await db.rollback()
logger.error(
@@ -218,7 +204,6 @@ class CRUDBase[
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
@@ -247,7 +232,7 @@ class CRUDBase[
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(
raise IntegrityConstraintError(
f"Cannot delete {self.model.__name__}: referenced by other records"
)
except Exception as e:
@@ -272,57 +257,40 @@ class CRUDBase[
Get multiple records with total count, filtering, and sorting.
NOTE: This method is defensive code that's never called in practice.
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
with their own implementations that include additional parameters like search.
All repository subclasses override this method with their own implementations.
Marked as pragma: no cover to avoid false coverage gaps.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns:
Tuple of (items, total_count)
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
raise InvalidInputError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
raise InvalidInputError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
raise InvalidInputError("Maximum limit is 1000")
try:
# Build base query
query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, "deleted_at"):
query = query.where(self.model.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.where(getattr(self.model, field) == value)
# Get total count (before pagination)
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
else:
query = query.order_by(self.model.id)
# Apply pagination
query = query.offset(skip).limit(limit)
items_result = await db.execute(query)
items = list(items_result.scalars().all())
@@ -356,7 +324,6 @@ class CRUDBase[
"""
from datetime import datetime
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
@@ -378,14 +345,12 @@ class CRUDBase[
)
return None
# Check if model supports soft deletes
if not hasattr(self.model, "deleted_at"):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(
raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column"
)
# Set deleted_at timestamp
obj.deleted_at = datetime.now(UTC)
db.add(obj)
await db.commit()
@@ -405,7 +370,6 @@ class CRUDBase[
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
@@ -416,7 +380,6 @@ class CRUDBase[
return None
try:
# Find the soft-deleted record
if hasattr(self.model, "deleted_at"):
result = await db.execute(
select(self.model).where(
@@ -426,7 +389,7 @@ class CRUDBase[
obj = result.scalar_one_or_none()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(
raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column"
)
@@ -436,7 +399,6 @@ class CRUDBase[
)
return None
# Clear deleted_at timestamp
obj.deleted_at = None
db.add(obj)
await db.commit()

View File

@@ -0,0 +1,237 @@
# app/repositories/oauth_account.py
"""Repository for OAuthAccount model async CRUD operations."""
import logging
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_account import OAuthAccount
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthAccountCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthAccountRepository(
BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]
):
"""Repository for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and provider user ID."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and email."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider} email {email}: {e!s}"
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""Get all OAuth accounts linked to a user."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""Get a specific OAuth account for a user and provider."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""Create a new OAuth account link."""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token=obj_in.access_token,
refresh_token=obj_in.refresh_token,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
)
raise DuplicateEntryError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error(f"Integrity error creating OAuth account: {error_msg}")
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""Delete an OAuth account link."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
f"OAuth account deleted: {provider} unlinked from user {user_id}"
)
else:
logger.warning(
f"OAuth account not found for deletion: {provider} for user {user_id}"
)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token: str | None = None,
refresh_token: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""Update OAuth tokens for an account."""
try:
if access_token is not None:
account.access_token = access_token
if refresh_token is not None:
account.refresh_token = refresh_token
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error updating OAuth tokens: {e!s}")
raise
# Singleton instance
oauth_account_repo = OAuthAccountRepository(OAuthAccount)

View File

@@ -0,0 +1,108 @@
# app/repositories/oauth_authorization_code.py
"""Repository for OAuthAuthorizationCode model."""
import logging
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_authorization_code import OAuthAuthorizationCode
logger = logging.getLogger(__name__)
class OAuthAuthorizationCodeRepository:
"""Repository for OAuth 2.0 authorization codes."""
async def create_code(
self,
db: AsyncSession,
*,
code: str,
client_id: str,
user_id: UUID,
redirect_uri: str,
scope: str,
expires_at: datetime,
code_challenge: str | None = None,
code_challenge_method: str | None = None,
state: str | None = None,
nonce: str | None = None,
) -> OAuthAuthorizationCode:
"""Create and persist a new authorization code."""
auth_code = OAuthAuthorizationCode(
code=code,
client_id=client_id,
user_id=user_id,
redirect_uri=redirect_uri,
scope=scope,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
expires_at=expires_at,
used=False,
)
db.add(auth_code)
await db.commit()
return auth_code
async def consume_code_atomically(
self, db: AsyncSession, *, code: str
) -> UUID | None:
"""
Atomically mark a code as used and return its UUID.
Returns the UUID if the code was found and not yet used, None otherwise.
This prevents race conditions per RFC 6749 Section 4.1.2.
"""
stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
)
result = await db.execute(stmt)
row_id = result.scalar_one_or_none()
if row_id is not None:
await db.commit()
return row_id
async def get_by_id(
self, db: AsyncSession, *, code_id: UUID
) -> OAuthAuthorizationCode | None:
"""Get authorization code by its UUID primary key."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
)
return result.scalar_one_or_none()
async def get_by_code(
self, db: AsyncSession, *, code: str
) -> OAuthAuthorizationCode | None:
"""Get authorization code by the code string value."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
)
return result.scalar_one_or_none()
async def cleanup_expired(self, db: AsyncSession) -> int:
"""Delete all expired authorization codes. Returns count deleted."""
result = await db.execute(
delete(OAuthAuthorizationCode).where(
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
# Singleton instance
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()

View File

@@ -0,0 +1,201 @@
# app/repositories/oauth_client.py
"""Repository for OAuthClient model async CRUD operations."""
import logging
import secrets
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_client import OAuthClient
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthClientCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthClientRepository(
BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]
):
"""Repository for OAuth clients (provider mode)."""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Get OAuth client by client_id."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""Create a new OAuth client."""
try:
client_id = secrets.token_urlsafe(32)
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
from app.core.auth import get_password_hash
client_secret_hash = get_password_hash(client_secret)
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
)
return db_obj, client_secret
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Error creating OAuth client: {error_msg}")
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Deactivate an OAuth client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info(f"OAuth client deactivated: {client.client_name}")
return client
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""Validate that a redirect URI is allowed for a client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e: # pragma: no cover
logger.error(f"Error validating redirect URI: {e!s}")
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""Verify client credentials."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
from app.core.auth import verify_password
stored_hash: str = str(client.client_secret_hash)
if stored_hash.startswith("$2"):
return verify_password(client_secret, stored_hash)
else:
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e: # pragma: no cover
logger.error(f"Error verifying client secret: {e!s}")
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""Get all OAuth clients."""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting all OAuth clients: {e!s}")
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""Delete an OAuth client permanently."""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(f"OAuth client deleted: {client_id}")
else:
logger.warning(f"OAuth client not found for deletion: {client_id}")
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
raise
# Singleton instance
oauth_client_repo = OAuthClientRepository(OAuthClient)

View File

@@ -0,0 +1,113 @@
# app/repositories/oauth_consent.py
"""Repository for OAuthConsent model."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent
logger = logging.getLogger(__name__)
class OAuthConsentRepository:
"""Repository for OAuth consent records (user grants to clients)."""
async def get_consent(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> OAuthConsent | None:
"""Get the consent record for a user-client pair, or None if not found."""
result = await db.execute(
select(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
return result.scalar_one_or_none()
async def grant_consent(
self,
db: AsyncSession,
*,
user_id: UUID,
client_id: str,
scopes: list[str],
) -> OAuthConsent:
"""
Create or update consent for a user-client pair.
If consent already exists, the new scopes are merged with existing ones.
Returns the created or updated consent record.
"""
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
if consent:
existing = (
set(consent.granted_scopes.split()) if consent.granted_scopes else set()
)
merged = existing | set(scopes)
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
else:
consent = OAuthConsent(
user_id=user_id,
client_id=client_id,
granted_scopes=" ".join(sorted(set(scopes))),
)
db.add(consent)
await db.commit()
await db.refresh(consent)
return consent
async def get_user_consents_with_clients(
self, db: AsyncSession, *, user_id: UUID
) -> list[dict[str, Any]]:
"""Get all consent records for a user joined with client details."""
result = await db.execute(
select(OAuthConsent, OAuthClient)
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
.where(OAuthConsent.user_id == user_id)
)
rows = result.all()
return [
{
"client_id": consent.client_id,
"client_name": client.client_name,
"client_description": client.client_description,
"granted_scopes": consent.granted_scopes.split()
if consent.granted_scopes
else [],
"granted_at": consent.created_at.isoformat(),
}
for consent, client in rows
]
async def revoke_consent(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> bool:
"""
Delete the consent record for a user-client pair.
Returns True if a record was found and deleted.
Note: Callers are responsible for also revoking associated tokens.
"""
result = await db.execute(
delete(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
await db.commit()
return result.rowcount > 0 # type: ignore[attr-defined]
# Singleton instance
oauth_consent_repo = OAuthConsentRepository()

View File

@@ -0,0 +1,142 @@
# app/repositories/oauth_provider_token.py
"""Repository for OAuthProviderRefreshToken model."""
import logging
from datetime import UTC, datetime, timedelta
from uuid import UUID
from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_provider_token import OAuthProviderRefreshToken
logger = logging.getLogger(__name__)
class OAuthProviderTokenRepository:
"""Repository for OAuth provider refresh tokens."""
async def create_token(
self,
db: AsyncSession,
*,
token_hash: str,
jti: str,
client_id: str,
user_id: UUID,
scope: str,
expires_at: datetime,
device_info: str | None = None,
ip_address: str | None = None,
) -> OAuthProviderRefreshToken:
"""Create and persist a new refresh token record."""
token = OAuthProviderRefreshToken(
token_hash=token_hash,
jti=jti,
client_id=client_id,
user_id=user_id,
scope=scope,
expires_at=expires_at,
device_info=device_info,
ip_address=ip_address,
)
db.add(token)
await db.commit()
return token
async def get_by_token_hash(
self, db: AsyncSession, *, token_hash: str
) -> OAuthProviderRefreshToken | None:
"""Get refresh token record by SHA-256 token hash."""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
)
return result.scalar_one_or_none()
async def get_by_jti(
self, db: AsyncSession, *, jti: str
) -> OAuthProviderRefreshToken | None:
"""Get refresh token record by JWT ID (JTI)."""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
return result.scalar_one_or_none()
async def revoke(
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
) -> None:
"""Mark a specific token record as revoked."""
token.revoked = True # type: ignore[assignment]
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
await db.commit()
async def revoke_all_for_user_client(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> int:
"""
Revoke all active tokens for a specific user-client pair.
Used when security incidents are detected (e.g., authorization code reuse).
Returns the number of tokens revoked.
"""
result = await db.execute(
update(OAuthProviderRefreshToken)
.where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.client_id == client_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
.values(revoked=True)
)
count = result.rowcount # type: ignore[attr-defined]
if count > 0:
await db.commit()
return count
async def revoke_all_for_user(self, db: AsyncSession, *, user_id: UUID) -> int:
"""
Revoke all active tokens for a user across all clients.
Used when user changes password or logs out everywhere.
Returns the number of tokens revoked.
"""
result = await db.execute(
update(OAuthProviderRefreshToken)
.where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
.values(revoked=True)
)
count = result.rowcount # type: ignore[attr-defined]
if count > 0:
await db.commit()
return count
async def cleanup_expired(self, db: AsyncSession, *, cutoff_days: int = 7) -> int:
"""
Delete expired refresh tokens older than cutoff_days.
Should be called periodically (e.g., daily).
Returns the number of tokens deleted.
"""
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
result = await db.execute(
delete(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.expires_at < cutoff
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
# Singleton instance
oauth_provider_token_repo = OAuthProviderTokenRepository()

View File

@@ -0,0 +1,113 @@
# app/repositories/oauth_state.py
"""Repository for OAuthState model async CRUD operations."""
import logging
from datetime import UTC, datetime
from pydantic import BaseModel
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_state import OAuthState
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthStateCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
"""Repository for OAuth state (CSRF protection)."""
async def create_state(
self, db: AsyncSession, *, obj_in: OAuthStateCreate
) -> OAuthState:
"""Create a new OAuth state for CSRF protection."""
try:
db_obj = OAuthState(
state=obj_in.state,
code_verifier=obj_in.code_verifier,
nonce=obj_in.nonce,
provider=obj_in.provider,
redirect_uri=obj_in.redirect_uri,
user_id=obj_in.user_id,
expires_at=obj_in.expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.debug(f"OAuth state created for {obj_in.provider}")
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"OAuth state collision: {error_msg}")
raise DuplicateEntryError("Failed to create OAuth state, please retry")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
raise
async def get_and_consume_state(
self, db: AsyncSession, *, state: str
) -> OAuthState | None:
"""Get and delete OAuth state (consume it)."""
try:
result = await db.execute(
select(OAuthState).where(OAuthState.state == state)
)
db_obj = result.scalar_one_or_none()
if db_obj is None:
logger.warning(f"OAuth state not found: {state[:8]}...")
return None
now = datetime.now(UTC)
expires_at = db_obj.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.warning(f"OAuth state expired: {state[:8]}...")
await db.delete(db_obj)
await db.commit()
return None
await db.delete(db_obj)
await db.commit()
logger.debug(f"OAuth state consumed: {state[:8]}...")
return db_obj
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error consuming OAuth state: {e!s}")
raise
async def cleanup_expired(self, db: AsyncSession) -> int:
"""Clean up expired OAuth states."""
try:
now = datetime.now(UTC)
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired OAuth states")
return count
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
raise
# Singleton instance
oauth_state_repo = OAuthStateRepository(OAuthState)

View File

@@ -1,5 +1,5 @@
# app/crud/organization_async.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
# app/repositories/organization.py
"""Repository for Organization model async CRUD operations using SQLAlchemy 2.0 patterns."""
import logging
from typing import Any
@@ -9,10 +9,11 @@ from sqlalchemy import and_, case, func, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
from app.models.organization import Organization
from app.models.user import User
from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.base import BaseRepository
from app.schemas.organizations import (
OrganizationCreate,
OrganizationUpdate,
@@ -21,8 +22,10 @@ from app.schemas.organizations import (
logger = logging.getLogger(__name__)
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
"""Async CRUD operations for Organization model."""
class OrganizationRepository(
BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]
):
"""Repository for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
"""Get organization by slug."""
@@ -54,13 +57,17 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower():
if (
"slug" in error_msg.lower()
or "unique" in error_msg.lower()
or "duplicate" in error_msg.lower()
):
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(
raise DuplicateEntryError(
f"Organization with slug '{obj_in.slug}' already exists"
)
logger.error(f"Integrity error creating organization: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(
@@ -79,16 +86,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
sort_by: str = "created_at",
sort_order: str = "desc",
) -> tuple[list[Organization], int]:
"""
Get multiple organizations with filtering, searching, and sorting.
Returns:
Tuple of (organizations list, total count)
"""
"""Get multiple organizations with filtering, searching, and sorting."""
try:
query = select(Organization)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
@@ -100,19 +101,16 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
)
query = query.where(search_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(Organization, sort_by, Organization.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
organizations = list(result.scalars().all())
@@ -149,16 +147,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
is_active: bool | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
Returns:
Tuple of (list of dicts with org and member_count, total count)
"""
"""Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
try:
# Build base query with LEFT JOIN and GROUP BY
# Use CASE statement to count only active members
query = (
select(
Organization,
@@ -181,10 +171,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
.group_by(Organization.id)
)
# Apply filters
if is_active is not None:
query = query.where(Organization.is_active == is_active)
search_filter = None
if search:
search_filter = or_(
Organization.name.ilike(f"%{search}%"),
@@ -193,17 +183,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
)
query = query.where(search_filter)
# Get total count
count_query = select(func.count(Organization.id))
if is_active is not None:
count_query = count_query.where(Organization.is_active == is_active)
if search:
if search_filter is not None:
count_query = count_query.where(search_filter)
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination and ordering
query = (
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
)
@@ -211,7 +199,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
result = await db.execute(query)
rows = result.all()
# Convert to list of dicts
orgs_with_counts = [
{"organization": org, "member_count": member_count}
for org, member_count in rows
@@ -236,7 +223,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) -> UserOrganization:
"""Add a user to an organization with a specific role."""
try:
# Check if relationship already exists
result = await db.execute(
select(UserOrganization).where(
and_(
@@ -248,7 +234,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
existing = result.scalar_one_or_none()
if existing:
# Reactivate if inactive, or raise error if already active
if not existing.is_active:
existing.is_active = True
existing.role = role
@@ -257,9 +242,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
await db.refresh(existing)
return existing
else:
raise ValueError("User is already a member of this organization")
raise DuplicateEntryError(
"User is already a member of this organization"
)
# Create new relationship
user_org = UserOrganization(
user_id=user_id,
organization_id=organization_id,
@@ -274,7 +260,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
except IntegrityError as e:
await db.rollback()
logger.error(f"Integrity error adding user to organization: {e!s}")
raise ValueError("Failed to add user to organization")
raise IntegrityConstraintError("Failed to add user to organization")
except Exception as e:
await db.rollback()
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
@@ -348,16 +334,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool = True,
is_active: bool | None = True,
) -> tuple[list[dict[str, Any]], int]:
"""
Get members of an organization with user details.
Returns:
Tuple of (members list with user details, total count)
"""
"""Get members of an organization with user details."""
try:
# Build query with join
query = (
select(UserOrganization, User)
.join(User, UserOrganization.user_id == User.id)
@@ -367,7 +347,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
if is_active is not None:
query = query.where(UserOrganization.is_active == is_active)
# Get total count
count_query = select(func.count()).select_from(
select(UserOrganization)
.where(UserOrganization.organization_id == organization_id)
@@ -381,7 +360,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply ordering and pagination
query = (
query.order_by(UserOrganization.created_at.desc())
.offset(skip)
@@ -410,7 +388,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
raise
async def get_user_organizations(
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
) -> list[Organization]:
"""Get all organizations a user belongs to."""
try:
@@ -433,17 +411,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
raise
async def get_user_organizations_with_details(
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
) -> list[dict[str, Any]]:
"""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
Returns:
List of dicts with organization, role, and member_count
"""
"""Get user's organizations with role and member count in SINGLE QUERY."""
try:
# Subquery to get member counts for each organization
member_count_subq = (
select(
UserOrganization.organization_id,
@@ -454,7 +425,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
.subquery()
)
# Main query with JOIN to get org, role, and member count
query = (
select(
Organization,
@@ -507,7 +477,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
)
user_org = result.scalar_one_or_none()
return user_org.role if user_org else None
return user_org.role if user_org else None # pyright: ignore[reportReturnType]
except Exception as e:
logger.error(f"Error getting user role in org: {e!s}")
raise
@@ -531,5 +501,5 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
# Create a singleton instance for use across the application
organization = CRUDOrganization(Organization)
# Singleton instance
organization_repo = OrganizationRepository(Organization)

View File

@@ -1,6 +1,5 @@
"""
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
"""
# app/repositories/session.py
"""Repository for UserSession model async CRUD operations using SQLAlchemy 2.0 patterns."""
import logging
import uuid
@@ -11,27 +10,19 @@ from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.core.repository_exceptions import IntegrityConstraintError, InvalidInputError
from app.models.user_session import UserSession
from app.repositories.base import BaseRepository
from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions."""
class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
"""Repository for UserSession model."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""
Get session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
"""
"""Get session by refresh token JTI."""
try:
result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti)
@@ -44,16 +35,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
async def get_active_by_jti(
self, db: AsyncSession, *, jti: str
) -> UserSession | None:
"""
Get active session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
"""Get active session by refresh token JTI."""
try:
result = await db.execute(
select(UserSession).where(
@@ -76,25 +58,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
active_only: bool = True,
with_user: bool = False,
) -> list[UserSession]:
"""
Get all sessions for a user with optional eager loading.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
List of UserSession objects
"""
"""Get all sessions for a user with optional eager loading."""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user:
query = query.options(joinedload(UserSession.user))
@@ -111,19 +80,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
async def create_session(
self, db: AsyncSession, *, obj_in: SessionCreate
) -> UserSession:
"""
Create a new user session.
Args:
db: Database session
obj_in: SessionCreate schema with session data
Returns:
Created UserSession
Raises:
ValueError: If session creation fails
"""
"""Create a new user session."""
try:
db_obj = UserSession(
user_id=obj_in.user_id,
@@ -151,21 +108,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
except Exception as e:
await db.rollback()
logger.error(f"Error creating session: {e!s}", exc_info=True)
raise ValueError(f"Failed to create session: {e!s}")
raise IntegrityConstraintError(f"Failed to create session: {e!s}")
async def deactivate(
self, db: AsyncSession, *, session_id: str
) -> UserSession | None:
"""
Deactivate a session (logout from device).
Args:
db: Database session
session_id: Session UUID
Returns:
Deactivated UserSession if found, None otherwise
"""
"""Deactivate a session (logout from device)."""
try:
session = await self.get(db, id=session_id)
if not session:
@@ -191,18 +139,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
async def deactivate_all_user_sessions(
self, db: AsyncSession, *, user_id: str
) -> int:
"""
Deactivate all active sessions for a user (logout from all devices).
Args:
db: Database session
user_id: User ID
Returns:
Number of sessions deactivated
"""
"""Deactivate all active sessions for a user (logout from all devices)."""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
stmt = (
@@ -227,16 +165,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
async def update_last_used(
self, db: AsyncSession, *, session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
Args:
db: Database session
session: UserSession object
Returns:
Updated UserSession
"""
"""Update the last_used_at timestamp for a session."""
try:
session.last_used_at = datetime.now(UTC)
db.add(session)
@@ -256,20 +185,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
new_jti: str,
new_expires_at: datetime,
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
Called during token refresh.
Args:
db: Database session
session: UserSession object
new_jti: New refresh token JTI
new_expires_at: New expiration datetime
Returns:
Updated UserSession
"""
"""Update session with new refresh token JTI and expiration."""
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
@@ -286,27 +202,11 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions using optimized bulk DELETE.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Uses single DELETE query instead of N individual deletes for efficiency.
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
Returns:
Number of sessions deleted
"""
"""Clean up expired sessions using optimized bulk DELETE."""
try:
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.is_active == False, # noqa: E712
@@ -330,29 +230,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
raise
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
"""
Clean up expired and inactive sessions for a specific user.
Uses single bulk DELETE query for efficiency instead of N individual deletes.
Args:
db: Database session
user_id: User ID to cleanup sessions for
Returns:
Number of sessions deleted
"""
"""Clean up expired and inactive sessions for a specific user."""
try:
# Validate UUID
try:
uuid_obj = uuid.UUID(user_id)
except (ValueError, AttributeError):
logger.error(f"Invalid UUID format: {user_id}")
raise ValueError(f"Invalid user ID format: {user_id}")
raise InvalidInputError(f"Invalid user ID format: {user_id}")
now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where(
and_(
UserSession.user_id == uuid_obj,
@@ -380,18 +267,8 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
"""
Get count of active sessions for a user.
Args:
db: Database session
user_id: User ID
Returns:
Number of active sessions
"""
"""Get count of active sessions for a user."""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute(
@@ -413,31 +290,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
active_only: bool = True,
with_user: bool = True,
) -> tuple[list[UserSession], int]:
"""
Get all sessions across all users with pagination (admin only).
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
Tuple of (list of UserSession objects, total count)
"""
"""Get all sessions across all users with pagination (admin only)."""
try:
# Build query
query = select(UserSession)
# Add eager loading if requested to prevent N+1 queries
if with_user:
query = query.options(joinedload(UserSession.user))
if active_only:
query = query.where(UserSession.is_active)
# Get total count
count_query = select(func.count(UserSession.id))
if active_only:
count_query = count_query.where(UserSession.is_active)
@@ -445,7 +307,6 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination and ordering
query = (
query.order_by(UserSession.last_used_at.desc())
.offset(skip)
@@ -462,5 +323,5 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
raise
# Create singleton instance
session = CRUDSession(UserSession)
# Singleton instance
session_repo = SessionRepository(UserSession)

View File

@@ -1,5 +1,5 @@
# app/crud/user_async.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
# app/repositories/user.py
"""Repository for User model async CRUD operations using SQLAlchemy 2.0 patterns."""
import logging
from datetime import UTC, datetime
@@ -11,15 +11,16 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_password_hash_async
from app.crud.base import CRUDBase
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
from app.models.user import User
from app.repositories.base import BaseRepository
from app.schemas.users import UserCreate, UserUpdate
logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model."""
class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
"""Repository for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
"""Get user by email address."""
@@ -33,7 +34,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with async password hashing and error handling."""
try:
# Hash password asynchronously to avoid blocking event loop
password_hash = await get_password_hash_async(obj_in.password)
db_obj = User(
@@ -58,14 +58,50 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise ValueError(f"User with email {obj_in.email} already exists")
raise DuplicateEntryError(
f"User with email {obj_in.email} already exists"
)
logger.error(f"Integrity error creating user: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
raise
async def create_oauth_user(
self,
db: AsyncSession,
*,
email: str,
first_name: str = "User",
last_name: str | None = None,
) -> User:
"""Create a new passwordless user for OAuth sign-in."""
try:
db_obj = User(
email=email,
password_hash=None, # OAuth-only user
first_name=first_name,
last_name=last_name,
is_active=True,
is_superuser=False,
)
db.add(db_obj)
await db.flush() # Get user.id without committing
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {email}")
raise DuplicateEntryError(f"User with email {email} already exists")
logger.error(f"Integrity error creating OAuth user: {error_msg}")
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating OAuth user: {e!s}", exc_info=True)
raise
async def update(
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
) -> User:
@@ -75,8 +111,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
else:
update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async(
update_data["password"]
@@ -85,6 +119,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return await super().update(db, db_obj=db_obj, obj_in=update_data)
async def update_password(
self, db: AsyncSession, *, user: User, password_hash: str
) -> User:
"""Set a new password hash on a user and commit."""
user.password_hash = password_hash
await db.commit()
await db.refresh(user)
return user
async def get_multi_with_total(
self,
db: AsyncSession,
@@ -96,43 +139,23 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
filters: dict[str, Any] | None = None,
search: str | None = None,
) -> tuple[list[User], int]:
"""
Get multiple users with total count, filtering, sorting, and search.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
search: Search term to match against email, first_name, last_name
Returns:
Tuple of (users list, total count)
"""
# Validate pagination
"""Get multiple users with total count, filtering, sorting, and search."""
if skip < 0:
raise ValueError("skip must be non-negative")
raise InvalidInputError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
raise InvalidInputError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
raise InvalidInputError("Maximum limit is 1000")
try:
# Build base query
query = select(User)
# Exclude soft-deleted users
query = query.where(User.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(User, field) and value is not None:
query = query.where(getattr(User, field) == value)
# Apply search
if search:
search_filter = or_(
User.email.ilike(f"%{search}%"),
@@ -141,14 +164,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
)
query = query.where(search_filter)
# Get total count
from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(User, sort_by):
sort_column = getattr(User, sort_by)
if sort_order.lower() == "desc":
@@ -156,7 +177,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
users = list(result.scalars().all())
@@ -170,26 +190,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
async def bulk_update_status(
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
) -> int:
"""
Bulk update is_active status for multiple users.
Args:
db: Database session
user_ids: List of user IDs to update
is_active: New active status
Returns:
Number of users updated
"""
"""Bulk update is_active status for multiple users."""
try:
if not user_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users
.where(User.deleted_at.is_(None))
.values(is_active=is_active, updated_at=datetime.now(UTC))
)
@@ -212,34 +221,20 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
user_ids: list[UUID],
exclude_user_id: UUID | None = None,
) -> int:
"""
Bulk soft delete multiple users.
Args:
db: Database session
user_ids: List of user IDs to delete
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
Returns:
Number of users deleted
"""
"""Bulk soft delete multiple users."""
try:
if not user_ids:
return 0
# Remove excluded user from list
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids:
return 0
# Use UPDATE with WHERE IN for efficiency
stmt = (
update(User)
.where(User.id.in_(filtered_ids))
.where(
User.deleted_at.is_(None)
) # Don't re-delete already deleted users
.where(User.deleted_at.is_(None))
.values(
deleted_at=datetime.now(UTC),
is_active=False,
@@ -261,12 +256,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
def is_active(self, user: User) -> bool:
"""Check if user is active."""
return user.is_active
return bool(user.is_active)
def is_superuser(self, user: User) -> bool:
"""Check if user is a superuser."""
return user.is_superuser
return bool(user.is_superuser)
# Create a singleton instance for use across the application
user = CRUDUser(User)
# Singleton instance
user_repo = UserRepository(User)

View File

@@ -60,8 +60,8 @@ class OAuthAccountCreate(OAuthAccountBase):
user_id: UUID
provider_user_id: str = Field(..., max_length=255)
access_token_encrypted: str | None = None
refresh_token_encrypted: str | None = None
access_token: str | None = None
refresh_token: str | None = None
token_expires_at: datetime | None = None

View File

@@ -48,7 +48,7 @@ class OrganizationCreate(OrganizationBase):
"""Schema for creating a new organization."""
name: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255) # pyright: ignore[reportIncompatibleVariableOverride]
class OrganizationUpdate(BaseModel):

View File

@@ -1,5 +1,19 @@
# app/services/__init__.py
from . import oauth_provider_service
from .auth_service import AuthService
from .oauth_service import OAuthService
from .organization_service import OrganizationService, organization_service
from .session_service import SessionService, session_service
from .user_service import UserService, user_service
__all__ = ["AuthService", "OAuthService"]
__all__ = [
"AuthService",
"OAuthService",
"OrganizationService",
"SessionService",
"UserService",
"oauth_provider_service",
"organization_service",
"session_service",
"user_service",
]

View File

@@ -2,7 +2,6 @@
import logging
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import (
@@ -14,12 +13,18 @@ from app.core.auth import (
verify_password_async,
)
from app.core.config import settings
from app.core.exceptions import AuthenticationError
from app.core.exceptions import AuthenticationError, DuplicateError
from app.core.repository_exceptions import DuplicateEntryError
from app.models.user import User
from app.repositories.user import user_repo
from app.schemas.users import Token, UserCreate, UserResponse
logger = logging.getLogger(__name__)
# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
# preventing timing attacks that could enumerate valid email addresses.
_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"
class AuthService:
"""Service for handling authentication operations"""
@@ -39,10 +44,12 @@ class AuthService:
Returns:
User if authenticated, None otherwise
"""
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
user = await user_repo.get_by_email(db, email=email)
if not user:
# Perform a dummy verification to match timing of a real bcrypt check,
# preventing email enumeration via response-time differences.
await verify_password_async(password, _DUMMY_HASH)
return None
# Verify password asynchronously to avoid blocking event loop
@@ -71,39 +78,22 @@ class AuthService:
"""
try:
# Check if user already exists
result = await db.execute(select(User).where(User.email == user_data.email))
existing_user = result.scalar_one_or_none()
existing_user = await user_repo.get_by_email(db, email=user_data.email)
if existing_user:
raise AuthenticationError("User with this email already exists")
raise DuplicateError("User with this email already exists")
# Create new user with async password hashing
# Hash password asynchronously to avoid blocking event loop
hashed_password = await get_password_hash_async(user_data.password)
# Create user object from model
user = User(
email=user_data.email,
password_hash=hashed_password,
first_name=user_data.first_name,
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False,
)
db.add(user)
await db.commit()
await db.refresh(user)
# Delegate creation (hashing + commit) to the repository
user = await user_repo.create(db, obj_in=user_data)
logger.info(f"User created successfully: {user.email}")
return user
except AuthenticationError:
# Re-raise authentication errors without rollback
except (AuthenticationError, DuplicateError):
# Re-raise API exceptions without rollback
raise
except DuplicateEntryError as e:
raise DuplicateError(str(e))
except Exception as e:
# Rollback on any database errors
await db.rollback()
logger.error(f"Error creating user: {e!s}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {e!s}")
@@ -168,8 +158,7 @@ class AuthService:
user_id = token_data.user_id
# Get user from database
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
user = await user_repo.get(db, id=str(user_id))
if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account")
@@ -200,8 +189,7 @@ class AuthService:
AuthenticationError: If current password is incorrect or update fails
"""
try:
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
user = await user_repo.get(db, id=str(user_id))
if not user:
raise AuthenticationError("User not found")
@@ -210,8 +198,8 @@ class AuthService:
raise AuthenticationError("Current password is incorrect")
# Hash new password asynchronously to avoid blocking event loop
user.password_hash = await get_password_hash_async(new_password)
await db.commit()
new_hash = await get_password_hash_async(new_password)
await user_repo.update_password(db, user=user, password_hash=new_hash)
logger.info(f"Password changed successfully for user {user_id}")
return True
@@ -226,3 +214,32 @@ class AuthService:
f"Error changing password for user {user_id}: {e!s}", exc_info=True
)
raise AuthenticationError(f"Failed to change password: {e!s}")
@staticmethod
async def reset_password(
db: AsyncSession, *, email: str, new_password: str
) -> User:
"""
Reset a user's password without requiring the current password.
Args:
db: Database session
email: User email address
new_password: New password to set
Returns:
Updated user
Raises:
AuthenticationError: If user not found or inactive
"""
user = await user_repo.get_by_email(db, email=email)
if not user:
raise AuthenticationError("User not found")
if not user.is_active:
raise AuthenticationError("User account is inactive")
new_hash = await get_password_hash_async(new_password)
user = await user_repo.update_password(db, user=user, password_hash=new_hash)
logger.info(f"Password reset successfully for {email}")
return user

View File

@@ -25,15 +25,19 @@ from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID
from jose import jwt
from sqlalchemy import and_, delete, select
from jose import JWTError, jwt
from jose.exceptions import ExpiredSignatureError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.models.oauth_authorization_code import OAuthAuthorizationCode
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
from app.models.user import User
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
from app.repositories.oauth_client import oauth_client_repo
from app.repositories.oauth_consent import oauth_consent_repo
from app.repositories.oauth_provider_token import oauth_provider_token_repo
from app.repositories.user import user_repo
from app.schemas.oauth import OAuthClientCreate
logger = logging.getLogger(__name__)
@@ -161,15 +165,7 @@ def join_scope(scopes: list[str]) -> str:
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
"""Get OAuth client by client_id."""
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
async def validate_client(
@@ -204,21 +200,19 @@ async def validate_client(
if not client.client_secret_hash:
raise InvalidClientError("Client not configured with secret")
# SECURITY: Verify secret using bcrypt (not SHA-256)
# Supports both bcrypt and legacy SHA-256 hashes for migration
# SECURITY: Verify secret using bcrypt
from app.core.auth import verify_password
stored_hash = str(client.client_secret_hash)
if stored_hash.startswith("$2"):
# New bcrypt format
if not verify_password(client_secret, stored_hash):
raise InvalidClientError("Invalid client secret")
else:
# Legacy SHA-256 format
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
if not secrets.compare_digest(computed_hash, stored_hash):
raise InvalidClientError("Invalid client secret")
if not stored_hash.startswith("$2"):
raise InvalidClientError(
"Client secret uses deprecated hash format. "
"Please regenerate your client credentials."
)
if not verify_password(client_secret, stored_hash):
raise InvalidClientError("Invalid client secret")
return client
@@ -311,23 +305,20 @@ async def create_authorization_code(
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
)
auth_code = OAuthAuthorizationCode(
await oauth_authorization_code_repo.create_code(
db,
code=code,
client_id=client.client_id,
user_id=user.id,
redirect_uri=redirect_uri,
scope=scope,
expires_at=expires_at,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
expires_at=expires_at,
used=False,
)
db.add(auth_code)
await db.commit()
logger.info(
f"Created authorization code for user {user.id} and client {client.client_id}"
)
@@ -366,30 +357,14 @@ async def exchange_authorization_code(
"""
# Atomically mark code as used and fetch it (prevents race condition)
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
from sqlalchemy import update
# First, atomically mark the code as used and get affected count
update_stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
db, code=code
)
result = await db.execute(update_stmt)
updated_id = result.scalar_one_or_none()
if not updated_id:
# Either code doesn't exist or was already used
# Check if it exists to provide appropriate error
check_result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
)
existing_code = check_result.scalar_one_or_none()
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
if existing_code and existing_code.used:
# Code reuse is a security incident - revoke all tokens for this grant
@@ -404,11 +379,9 @@ async def exchange_authorization_code(
raise InvalidGrantError("Invalid authorization code")
# Now fetch the full auth code record
auth_code_result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
)
auth_code = auth_code_result.scalar_one()
await db.commit()
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
if auth_code is None:
raise InvalidGrantError("Authorization code not found after consumption")
if auth_code.is_expired:
raise InvalidGrantError("Authorization code has expired")
@@ -452,8 +425,7 @@ async def exchange_authorization_code(
raise InvalidGrantError("PKCE required for public clients")
# Get user
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
user = user_result.scalar_one_or_none()
user = await user_repo.get(db, id=str(auth_code.user_id))
if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive")
@@ -543,7 +515,8 @@ async def create_tokens(
refresh_token_hash = hash_token(refresh_token)
# Store refresh token in database
refresh_token_record = OAuthProviderRefreshToken(
await oauth_provider_token_repo.create_token(
db,
token_hash=refresh_token_hash,
jti=jti,
client_id=client.client_id,
@@ -553,8 +526,6 @@ async def create_tokens(
device_info=device_info,
ip_address=ip_address,
)
db.add(refresh_token_record)
await db.commit()
logger.info(f"Issued tokens for user {user.id} to client {client.client_id}")
@@ -599,12 +570,9 @@ async def refresh_tokens(
"""
# Find refresh token
token_hash = hash_token(refresh_token)
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
token_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
if not token_record:
raise InvalidGrantError("Invalid refresh token")
@@ -631,8 +599,7 @@ async def refresh_tokens(
)
# Get user
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
user = user_result.scalar_one_or_none()
user = await user_repo.get(db, id=str(token_record.user_id))
if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive")
@@ -648,9 +615,7 @@ async def refresh_tokens(
final_scope = token_scope
# Revoke old refresh token (token rotation)
token_record.revoked = True # type: ignore[assignment]
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
await db.commit()
await oauth_provider_token_repo.revoke(db, token=token_record)
# Issue new tokens
device = str(token_record.device_info) if token_record.device_info else None
@@ -697,28 +662,22 @@ async def revoke_token(
# Try as refresh token first (more likely)
if token_type_hint != "access_token":
token_hash = hash_token(token)
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
refresh_record = result.scalar_one_or_none()
if refresh_record:
# Validate client if provided
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
refresh_record.revoked = True # type: ignore[assignment]
await db.commit()
await oauth_provider_token_repo.revoke(db, token=refresh_record)
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
return True
# Try as access token (JWT)
if token_type_hint != "refresh_token":
try:
from jose.exceptions import JWTError
payload = jwt.decode(
token,
settings.SECRET_KEY,
@@ -731,22 +690,18 @@ async def revoke_token(
jti = payload.get("jti")
if jti:
# Find and revoke the associated refresh token
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
refresh_record = result.scalar_one_or_none()
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
if refresh_record:
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
refresh_record.revoked = True # type: ignore[assignment]
await db.commit()
await oauth_provider_token_repo.revoke(db, token=refresh_record)
logger.info(
f"Revoked refresh token via access token JTI {jti[:8]}..."
)
return True
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT not an error
except JWTError:
pass
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
pass
return False
@@ -770,24 +725,11 @@ async def revoke_tokens_for_user_client(
Returns:
Number of tokens revoked
"""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.client_id == client_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
count = await oauth_provider_token_repo.revoke_all_for_user_client(
db, user_id=user_id, client_id=client_id
)
tokens = result.scalars().all()
count = 0
for token in tokens:
token.revoked = True # type: ignore[assignment]
count += 1
if count > 0:
await db.commit()
logger.warning(
f"Revoked {count} tokens for user {user_id} and client {client_id}"
)
@@ -808,23 +750,9 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
Returns:
Number of tokens revoked
"""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
)
tokens = result.scalars().all()
count = 0
for token in tokens:
token.revoked = True # type: ignore[assignment]
count += 1
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
if count > 0:
await db.commit()
logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}")
return count
@@ -864,8 +792,6 @@ async def introspect_token(
# Try as access token (JWT) first
if token_type_hint != "refresh_token":
try:
from jose.exceptions import ExpiredSignatureError, JWTError
payload = jwt.decode(
token,
settings.SECRET_KEY,
@@ -878,12 +804,7 @@ async def introspect_token(
# Check if associated refresh token is revoked
jti = payload.get("jti")
if jti:
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
refresh_record = result.scalar_one_or_none()
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
if refresh_record and refresh_record.revoked:
return {"active": False}
@@ -901,18 +822,17 @@ async def introspect_token(
}
except ExpiredSignatureError:
return {"active": False}
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
except JWTError:
pass
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
pass
# Try as refresh token
if token_type_hint != "access_token":
token_hash = hash_token(token)
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
refresh_record = result.scalar_one_or_none()
if refresh_record and refresh_record.is_valid:
return {
@@ -937,17 +857,11 @@ async def get_consent(
db: AsyncSession,
user_id: UUID,
client_id: str,
) -> OAuthConsent | None:
):
"""Get existing consent record for user-client pair."""
result = await db.execute(
select(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
return await oauth_consent_repo.get_consent(
db, user_id=user_id, client_id=client_id
)
return result.scalar_one_or_none()
async def check_consent(
@@ -972,31 +886,15 @@ async def grant_consent(
user_id: UUID,
client_id: str,
scopes: list[str],
) -> OAuthConsent:
):
"""
Grant or update consent for a user-client pair.
If consent already exists, updates the granted scopes.
"""
consent = await get_consent(db, user_id, client_id)
if consent:
# Merge scopes
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
existing = set(parse_scope(granted))
new_scopes = existing | set(scopes)
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
else:
consent = OAuthConsent(
user_id=user_id,
client_id=client_id,
granted_scopes=join_scope(scopes),
)
db.add(consent)
await db.commit()
await db.refresh(consent)
return consent
return await oauth_consent_repo.grant_consent(
db, user_id=user_id, client_id=client_id, scopes=scopes
)
async def revoke_consent(
@@ -1009,21 +907,13 @@ async def revoke_consent(
Returns True if consent was found and revoked.
"""
# Delete consent record
result = await db.execute(
delete(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
# Revoke all tokens
# Revoke all tokens first
await revoke_tokens_for_user_client(db, user_id, client_id)
await db.commit()
return result.rowcount > 0 # type: ignore[attr-defined]
# Delete consent record
return await oauth_consent_repo.revoke_consent(
db, user_id=user_id, client_id=client_id
)
# ============================================================================
@@ -1031,6 +921,26 @@ async def revoke_consent(
# ============================================================================
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
"""Create a new OAuth client. Returns (client, secret)."""
return await oauth_client_repo.create_client(db, obj_in=client_data)
async def list_clients(db: AsyncSession) -> list:
"""List all registered OAuth clients."""
return await oauth_client_repo.get_all_clients(db)
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
"""Delete an OAuth client by client_id."""
await oauth_client_repo.delete_client(db, client_id=client_id)
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
"""Get all OAuth consents for a user with client details."""
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
async def cleanup_expired_codes(db: AsyncSession) -> int:
"""
Delete expired authorization codes.
@@ -1040,13 +950,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
Returns:
Number of codes deleted
"""
result = await db.execute(
delete(OAuthAuthorizationCode).where(
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
return await oauth_authorization_code_repo.cleanup_expired(db)
async def cleanup_expired_tokens(db: AsyncSession) -> int:
@@ -1058,12 +962,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
Returns:
Number of tokens deleted
"""
# Delete tokens that are both expired AND revoked (or just very old)
cutoff = datetime.now(UTC) - timedelta(days=7)
result = await db.execute(
delete(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.expires_at < cutoff
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)

View File

@@ -19,14 +19,15 @@ from typing import TypedDict, cast
from uuid import UUID
from authlib.integrations.httpx_client import AsyncOAuth2Client
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import create_access_token, create_refresh_token
from app.core.config import settings
from app.core.exceptions import AuthenticationError
from app.crud import oauth_account, oauth_state
from app.models.user import User
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.repositories.user import user_repo
from app.schemas.oauth import (
OAuthAccountCreate,
OAuthCallbackResponse,
@@ -38,19 +39,22 @@ from app.schemas.oauth import (
logger = logging.getLogger(__name__)
class OAuthProviderConfig(TypedDict, total=False):
"""Type definition for OAuth provider configuration."""
class _OAuthProviderConfigRequired(TypedDict):
name: str
icon: str
authorize_url: str
token_url: str
userinfo_url: str
email_url: str # Optional, GitHub-only
scopes: list[str]
supports_pkce: bool
class OAuthProviderConfig(_OAuthProviderConfigRequired, total=False):
"""Type definition for OAuth provider configuration."""
email_url: str # Optional, GitHub-only
# Provider configurations
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
"google": {
@@ -343,7 +347,9 @@ class OAuthService:
await oauth_account.update_tokens(
db,
account=existing_oauth,
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600)),
)
@@ -351,10 +357,7 @@ class OAuthService:
elif state_record.user_id:
# Account linking flow (user is already logged in)
result = await db.execute(
select(User).where(User.id == state_record.user_id)
)
user = result.scalar_one_or_none()
user = await user_repo.get(db, id=str(state_record.user_id))
if not user:
raise AuthenticationError("User not found for account linking")
@@ -375,7 +378,9 @@ class OAuthService:
provider=provider,
provider_user_id=provider_user_id,
provider_email=provider_email,
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
@@ -389,10 +394,7 @@ class OAuthService:
user = None
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
result = await db.execute(
select(User).where(User.email == provider_email)
)
user = result.scalar_one_or_none()
user = await user_repo.get_by_email(db, email=provider_email)
if user:
# Auto-link to existing user
@@ -416,8 +418,8 @@ class OAuthService:
provider=provider,
provider_user_id=provider_user_id,
provider_email=provider_email,
access_token_encrypted=token.get("access_token"),
refresh_token_encrypted=token.get("refresh_token"),
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
@@ -486,7 +488,7 @@ class OAuthService:
# GitHub requires separate request for email
if provider == "github" and not user_info.get("email"):
email_resp = await client.get(
config["email_url"],
config["email_url"], # pyright: ignore[reportTypedDictNotRequiredAccess]
headers=headers,
)
email_resp.raise_for_status()
@@ -644,14 +646,15 @@ class OAuthService:
provider=provider,
provider_user_id=provider_user_id,
provider_email=email,
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
)
await oauth_account.create_account(db, obj_in=oauth_create)
await db.commit()
await db.refresh(user)
return user
@@ -701,6 +704,20 @@ class OAuthService:
logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
return True
@staticmethod
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
"""Get all OAuth accounts linked to a user."""
return await oauth_account.get_user_accounts(db, user_id=user_id)
@staticmethod
async def get_user_account_by_provider(
db: AsyncSession, *, user_id: UUID, provider: str
):
"""Get a specific OAuth account for a user and provider."""
return await oauth_account.get_user_account_by_provider(
db, user_id=user_id, provider=provider
)
@staticmethod
async def cleanup_expired_states(db: AsyncSession) -> int:
"""

View File

@@ -0,0 +1,155 @@
# app/services/organization_service.py
"""Service layer for organization operations — delegates to OrganizationRepository."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import NotFoundError
from app.models.organization import Organization
from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.organization import OrganizationRepository, organization_repo
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
logger = logging.getLogger(__name__)
class OrganizationService:
"""Service for organization management operations."""
def __init__(
self, organization_repository: OrganizationRepository | None = None
) -> None:
self._repo = organization_repository or organization_repo
async def get_organization(self, db: AsyncSession, org_id: str) -> Organization:
"""Get organization by ID, raising NotFoundError if not found."""
org = await self._repo.get(db, id=org_id)
if not org:
raise NotFoundError(f"Organization {org_id} not found")
return org
async def create_organization(
self, db: AsyncSession, *, obj_in: OrganizationCreate
) -> Organization:
"""Create a new organization."""
return await self._repo.create(db, obj_in=obj_in)
async def update_organization(
self,
db: AsyncSession,
*,
org: Organization,
obj_in: OrganizationUpdate | dict[str, Any],
) -> Organization:
"""Update an existing organization."""
return await self._repo.update(db, db_obj=org, obj_in=obj_in)
async def remove_organization(self, db: AsyncSession, org_id: str) -> None:
"""Permanently delete an organization by ID."""
await self._repo.remove(db, id=org_id)
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
"""Get number of active members in an organization."""
return await self._repo.get_member_count(db, organization_id=organization_id)
async def get_multi_with_member_counts(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""List organizations with member counts and pagination."""
return await self._repo.get_multi_with_member_counts(
db, skip=skip, limit=limit, is_active=is_active, search=search
)
async def get_user_organizations_with_details(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool | None = None,
) -> list[dict[str, Any]]:
"""Get all organizations a user belongs to, with membership details."""
return await self._repo.get_user_organizations_with_details(
db, user_id=user_id, is_active=is_active
)
async def get_organization_members(
self,
db: AsyncSession,
*,
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool | None = True,
) -> tuple[list[dict[str, Any]], int]:
"""Get members of an organization with pagination."""
return await self._repo.get_organization_members(
db,
organization_id=organization_id,
skip=skip,
limit=limit,
is_active=is_active,
)
async def add_member(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER,
) -> UserOrganization:
"""Add a user to an organization."""
return await self._repo.add_user(
db, organization_id=organization_id, user_id=user_id, role=role
)
async def remove_member(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
) -> bool:
"""Remove a user from an organization. Returns True if found and removed."""
return await self._repo.remove_user(
db, organization_id=organization_id, user_id=user_id
)
async def get_user_role_in_org(
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> OrganizationRole | None:
"""Get the role of a user in an organization."""
return await self._repo.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
async def get_org_distribution(
self, db: AsyncSession, *, limit: int = 6
) -> list[dict[str, Any]]:
"""Return top organizations by member count for admin dashboard."""
from sqlalchemy import func, select
result = await db.execute(
select(
Organization.name,
func.count(UserOrganization.user_id).label("count"),
)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.name)
.order_by(func.count(UserOrganization.user_id).desc())
.limit(limit)
)
return [{"name": row.name, "value": row.count} for row in result.all()]
# Default singleton
organization_service = OrganizationService()

View File

@@ -8,7 +8,7 @@ import logging
from datetime import UTC, datetime
from app.core.database import SessionLocal
from app.crud.session import session as session_crud
from app.repositories.session import session_repo as session_crud
logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,97 @@
# app/services/session_service.py
"""Service layer for session operations — delegates to SessionRepository."""
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.user_session import UserSession
from app.repositories.session import SessionRepository, session_repo
from app.schemas.sessions import SessionCreate
logger = logging.getLogger(__name__)
class SessionService:
"""Service for user session management operations."""
def __init__(self, session_repository: SessionRepository | None = None) -> None:
self._repo = session_repository or session_repo
async def create_session(
self, db: AsyncSession, *, obj_in: SessionCreate
) -> UserSession:
"""Create a new session record."""
return await self._repo.create_session(db, obj_in=obj_in)
async def get_session(
self, db: AsyncSession, session_id: str
) -> UserSession | None:
"""Get session by ID."""
return await self._repo.get(db, id=session_id)
async def get_user_sessions(
self, db: AsyncSession, *, user_id: str, active_only: bool = True
) -> list[UserSession]:
"""Get all sessions for a user."""
return await self._repo.get_user_sessions(
db, user_id=user_id, active_only=active_only
)
async def get_active_by_jti(
self, db: AsyncSession, *, jti: str
) -> UserSession | None:
"""Get active session by refresh token JTI."""
return await self._repo.get_active_by_jti(db, jti=jti)
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""Get session by refresh token JTI (active or inactive)."""
return await self._repo.get_by_jti(db, jti=jti)
async def deactivate(
self, db: AsyncSession, *, session_id: str
) -> UserSession | None:
"""Deactivate a session (logout from device)."""
return await self._repo.deactivate(db, session_id=session_id)
async def deactivate_all_user_sessions(
self, db: AsyncSession, *, user_id: str
) -> int:
"""Deactivate all sessions for a user. Returns count deactivated."""
return await self._repo.deactivate_all_user_sessions(db, user_id=user_id)
async def update_refresh_token(
self,
db: AsyncSession,
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime,
) -> UserSession:
"""Update session with a rotated refresh token."""
return await self._repo.update_refresh_token(
db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
)
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
"""Remove expired sessions for a user. Returns count removed."""
return await self._repo.cleanup_expired_for_user(db, user_id=user_id)
async def get_all_sessions(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
with_user: bool = True,
) -> tuple[list[UserSession], int]:
"""Get all sessions with pagination (admin only)."""
return await self._repo.get_all_sessions(
db, skip=skip, limit=limit, active_only=active_only, with_user=with_user
)
# Default singleton
session_service = SessionService()

View File

@@ -0,0 +1,120 @@
# app/services/user_service.py
"""Service layer for user operations — delegates to UserRepository."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import NotFoundError
from app.models.user import User
from app.repositories.user import UserRepository, user_repo
from app.schemas.users import UserCreate, UserUpdate
logger = logging.getLogger(__name__)
class UserService:
"""Service for user management operations."""
def __init__(self, user_repository: UserRepository | None = None) -> None:
self._repo = user_repository or user_repo
async def get_user(self, db: AsyncSession, user_id: str) -> User:
"""Get user by ID, raising NotFoundError if not found."""
user = await self._repo.get(db, id=user_id)
if not user:
raise NotFoundError(f"User {user_id} not found")
return user
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
"""Get user by email address."""
return await self._repo.get_by_email(db, email=email)
async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User:
"""Create a new user."""
return await self._repo.create(db, obj_in=user_data)
async def update_user(
self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any]
) -> User:
"""Update an existing user."""
return await self._repo.update(db, db_obj=user, obj_in=obj_in)
async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None:
"""Soft-delete a user by ID."""
await self._repo.soft_delete(db, id=user_id)
async def list_users(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: str | None = None,
sort_order: str = "asc",
filters: dict[str, Any] | None = None,
search: str | None = None,
) -> tuple[list[User], int]:
"""List users with pagination, sorting, filtering, and search."""
return await self._repo.get_multi_with_total(
db,
skip=skip,
limit=limit,
sort_by=sort_by,
sort_order=sort_order,
filters=filters,
search=search,
)
async def bulk_update_status(
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
) -> int:
"""Bulk update active status for multiple users. Returns count updated."""
return await self._repo.bulk_update_status(
db, user_ids=user_ids, is_active=is_active
)
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: list[UUID],
exclude_user_id: UUID | None = None,
) -> int:
"""Bulk soft-delete multiple users. Returns count deleted."""
return await self._repo.bulk_soft_delete(
db, user_ids=user_ids, exclude_user_id=exclude_user_id
)
async def get_stats(self, db: AsyncSession) -> dict[str, Any]:
"""Return user stats needed for the admin dashboard."""
from sqlalchemy import func, select
total_users = (
await db.execute(select(func.count()).select_from(User))
).scalar() or 0
active_count = (
await db.execute(
select(func.count()).select_from(User).where(User.is_active)
)
).scalar() or 0
inactive_count = (
await db.execute(
select(func.count()).select_from(User).where(User.is_active.is_(False))
)
).scalar() or 0
all_users = list(
(await db.execute(select(User).order_by(User.created_at))).scalars().all()
)
return {
"total_users": total_users,
"active_count": active_count,
"inactive_count": inactive_count,
"all_users": all_users,
}
# Default singleton
user_service = UserService()

View File

@@ -65,10 +65,10 @@ async def setup_async_test_db():
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
AsyncTestingSessionLocal = sessionmaker(
AsyncTestingSessionLocal = sessionmaker( # pyright: ignore[reportCallIssue]
autocommit=False,
autoflush=False,
bind=test_engine,
bind=test_engine, # pyright: ignore[reportArgumentType]
expire_on_commit=False,
class_=AsyncSession,
)

View File

@@ -72,7 +72,7 @@ dev = [
# Development tools
"ruff>=0.8.0", # All-in-one: linting, formatting, import sorting
"mypy>=1.8.0", # Type checking
"pyright>=1.1.390", # Type checking
]
# E2E testing with real PostgreSQL (requires Docker)
@@ -185,120 +185,6 @@ indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "lf"
# ============================================================================
# mypy Configuration - Type Checking
# ============================================================================
[tool.mypy]
python_version = "3.12"
warn_return_any = false # SQLAlchemy queries return Any - overly strict
warn_unused_configs = true
disallow_untyped_defs = false # Gradual typing - enable later
disallow_incomplete_defs = false
check_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
strict_equality = true
ignore_missing_imports = false
explicit_package_bases = true
namespace_packages = true
# Pydantic plugin for better validation
plugins = ["pydantic.mypy"]
# Per-module options
[[tool.mypy.overrides]]
module = "alembic.*"
ignore_errors = true
[[tool.mypy.overrides]]
module = "app.alembic.*"
ignore_errors = true
[[tool.mypy.overrides]]
module = "sqlalchemy.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "fastapi_utils.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "slowapi.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "jose.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "passlib.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "pydantic_settings.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "fastapi.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "apscheduler.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "starlette.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "authlib.*"
ignore_missing_imports = true
# SQLAlchemy ORM models - Column descriptors cause type confusion
[[tool.mypy.overrides]]
module = "app.models.*"
disable_error_code = ["assignment", "arg-type", "return-value"]
# CRUD operations - Generic ModelType and SQLAlchemy Result issues
[[tool.mypy.overrides]]
module = "app.crud.*"
disable_error_code = ["attr-defined", "assignment", "arg-type", "return-value"]
# API routes - SQLAlchemy Column to Pydantic schema conversions
[[tool.mypy.overrides]]
module = "app.api.routes.*"
disable_error_code = ["arg-type", "call-arg", "call-overload", "assignment"]
# API dependencies - Similar SQLAlchemy Column issues
[[tool.mypy.overrides]]
module = "app.api.dependencies.*"
disable_error_code = ["arg-type"]
# FastAPI exception handlers have correct signatures despite mypy warnings
[[tool.mypy.overrides]]
module = "app.main"
disable_error_code = ["arg-type"]
# Auth service - SQLAlchemy Column issues
[[tool.mypy.overrides]]
module = "app.services.auth_service"
disable_error_code = ["assignment", "arg-type"]
# Test utils - Testing patterns
[[tool.mypy.overrides]]
module = "app.utils.auth_test_utils"
disable_error_code = ["assignment", "arg-type"]
# ============================================================================
# Pydantic mypy plugin configuration
# ============================================================================
[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
# ============================================================================
# Pytest Configuration
# ============================================================================

View File

@@ -0,0 +1,23 @@
{
"include": ["app"],
"exclude": ["app/alembic"],
"pythonVersion": "3.12",
"venvPath": ".",
"venv": ".venv",
"typeCheckingMode": "standard",
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"reportUnknownMemberType": false,
"reportUnknownVariableType": false,
"reportUnknownArgumentType": false,
"reportUnknownParameterType": false,
"reportUnknownLambdaType": false,
"reportReturnType": true,
"reportUnusedImport": false,
"reportGeneralTypeIssues": false,
"reportAttributeAccessIssue": false,
"reportArgumentType": false,
"strictListInference": false,
"strictDictionaryInference": false,
"strictSetInference": false
}

View File

@@ -147,7 +147,7 @@ class TestAdminCreateUser:
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.status_code == status.HTTP_409_CONFLICT
class TestAdminGetUser:
@@ -565,7 +565,7 @@ class TestAdminCreateOrganization:
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.status_code == status.HTTP_409_CONFLICT
class TestAdminGetOrganization:

View File

@@ -45,7 +45,7 @@ class TestAdminListUsersFilters:
async def test_list_users_database_error_propagates(self, client, superuser_token):
"""Test that database errors propagate correctly (covers line 118-120)."""
with patch(
"app.api.routes.admin.user_crud.get_multi_with_total",
"app.api.routes.admin.user_service.list_users",
side_effect=Exception("DB error"),
):
with pytest.raises(Exception):
@@ -74,8 +74,8 @@ class TestAdminCreateUserErrors:
},
)
# Should get error for duplicate email
assert response.status_code == status.HTTP_404_NOT_FOUND
# Should get conflict for duplicate email
assert response.status_code == status.HTTP_409_CONFLICT
@pytest.mark.asyncio
async def test_create_user_unexpected_error_propagates(
@@ -83,7 +83,7 @@ class TestAdminCreateUserErrors:
):
"""Test unexpected errors during user creation (covers line 151-153)."""
with patch(
"app.api.routes.admin.user_crud.create",
"app.api.routes.admin.user_service.create_user",
side_effect=RuntimeError("Unexpected error"),
):
with pytest.raises(RuntimeError):
@@ -135,7 +135,7 @@ class TestAdminUpdateUserErrors:
):
"""Test unexpected errors during user update (covers line 206-208)."""
with patch(
"app.api.routes.admin.user_crud.update",
"app.api.routes.admin.user_service.update_user",
side_effect=RuntimeError("Update failed"),
):
with pytest.raises(RuntimeError):
@@ -166,7 +166,7 @@ class TestAdminDeleteUserErrors:
):
"""Test unexpected errors during user deletion (covers line 238-240)."""
with patch(
"app.api.routes.admin.user_crud.soft_delete",
"app.api.routes.admin.user_service.soft_delete_user",
side_effect=Exception("Delete failed"),
):
with pytest.raises(Exception):
@@ -196,7 +196,7 @@ class TestAdminActivateUserErrors:
):
"""Test unexpected errors during user activation (covers line 282-284)."""
with patch(
"app.api.routes.admin.user_crud.update",
"app.api.routes.admin.user_service.update_user",
side_effect=Exception("Activation failed"),
):
with pytest.raises(Exception):
@@ -238,7 +238,7 @@ class TestAdminDeactivateUserErrors:
):
"""Test unexpected errors during user deactivation (covers line 326-328)."""
with patch(
"app.api.routes.admin.user_crud.update",
"app.api.routes.admin.user_service.update_user",
side_effect=Exception("Deactivation failed"),
):
with pytest.raises(Exception):
@@ -258,7 +258,7 @@ class TestAdminListOrganizationsErrors:
async def test_list_organizations_database_error(self, client, superuser_token):
"""Test list organizations with database error (covers line 427-456)."""
with patch(
"app.api.routes.admin.organization_crud.get_multi_with_member_counts",
"app.api.routes.admin.organization_service.get_multi_with_member_counts",
side_effect=Exception("DB error"),
):
with pytest.raises(Exception):
@@ -299,14 +299,14 @@ class TestAdminCreateOrganizationErrors:
},
)
# Should get error for duplicate slug
assert response.status_code == status.HTTP_404_NOT_FOUND
# Should get conflict for duplicate slug
assert response.status_code == status.HTTP_409_CONFLICT
@pytest.mark.asyncio
async def test_create_organization_unexpected_error(self, client, superuser_token):
"""Test unexpected errors during organization creation (covers line 484-485)."""
with patch(
"app.api.routes.admin.organization_crud.create",
"app.api.routes.admin.organization_service.create_organization",
side_effect=RuntimeError("Creation failed"),
):
with pytest.raises(RuntimeError):
@@ -367,7 +367,7 @@ class TestAdminUpdateOrganizationErrors:
org_id = org.id
with patch(
"app.api.routes.admin.organization_crud.update",
"app.api.routes.admin.organization_service.update_organization",
side_effect=Exception("Update failed"),
):
with pytest.raises(Exception):
@@ -412,7 +412,7 @@ class TestAdminDeleteOrganizationErrors:
org_id = org.id
with patch(
"app.api.routes.admin.organization_crud.remove",
"app.api.routes.admin.organization_service.remove_organization",
side_effect=Exception("Delete failed"),
):
with pytest.raises(Exception):
@@ -456,7 +456,7 @@ class TestAdminListOrganizationMembersErrors:
org_id = org.id
with patch(
"app.api.routes.admin.organization_crud.get_organization_members",
"app.api.routes.admin.organization_service.get_organization_members",
side_effect=Exception("DB error"),
):
with pytest.raises(Exception):
@@ -531,7 +531,7 @@ class TestAdminAddOrganizationMemberErrors:
org_id = org.id
with patch(
"app.api.routes.admin.organization_crud.add_user",
"app.api.routes.admin.organization_service.add_member",
side_effect=Exception("Add failed"),
):
with pytest.raises(Exception):
@@ -587,7 +587,7 @@ class TestAdminRemoveOrganizationMemberErrors:
org_id = org.id
with patch(
"app.api.routes.admin.organization_crud.remove_user",
"app.api.routes.admin.organization_service.remove_member",
side_effect=Exception("Remove failed"),
):
with pytest.raises(Exception):

View File

@@ -19,7 +19,7 @@ class TestLoginSessionCreationFailure:
"""Test that login succeeds even if session creation fails."""
# Mock session creation to fail
with patch(
"app.api.routes.auth.session_crud.create_session",
"app.api.routes.auth.session_service.create_session",
side_effect=Exception("Session creation failed"),
):
response = await client.post(
@@ -43,7 +43,7 @@ class TestOAuthLoginSessionCreationFailure:
):
"""Test OAuth login succeeds even if session creation fails."""
with patch(
"app.api.routes.auth.session_crud.create_session",
"app.api.routes.auth.session_service.create_session",
side_effect=Exception("Session failed"),
):
response = await client.post(
@@ -76,7 +76,7 @@ class TestRefreshTokenSessionUpdateFailure:
# Mock session update to fail
with patch(
"app.api.routes.auth.session_crud.update_refresh_token",
"app.api.routes.auth.session_service.update_refresh_token",
side_effect=Exception("Update failed"),
):
response = await client.post(
@@ -130,7 +130,7 @@ class TestLogoutWithNonExistentSession:
tokens = response.json()
# Mock session lookup to return None
with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None):
with patch("app.api.routes.auth.session_service.get_by_jti", return_value=None):
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
@@ -157,7 +157,7 @@ class TestLogoutUnexpectedError:
# Mock to raise unexpected error
with patch(
"app.api.routes.auth.session_crud.get_by_jti",
"app.api.routes.auth.session_service.get_by_jti",
side_effect=Exception("Unexpected error"),
):
response = await client.post(
@@ -186,7 +186,7 @@ class TestLogoutAllUnexpectedError:
# Mock to raise database error
with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
"app.api.routes.auth.session_service.deactivate_all_user_sessions",
side_effect=Exception("DB error"),
):
response = await client.post(
@@ -212,7 +212,7 @@ class TestPasswordResetConfirmSessionInvalidation:
# Mock session invalidation to fail
with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions",
"app.api.routes.auth.session_service.deactivate_all_user_sessions",
side_effect=Exception("Invalidation failed"),
):
response = await client.post(

View File

@@ -334,7 +334,7 @@ class TestPasswordResetConfirm:
token = create_password_reset_token(async_test_user.email)
# Mock the database commit to raise an exception
with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get:
with patch("app.services.auth_service.user_repo.get_by_email") as mock_get:
mock_get.side_effect = Exception("Database error")
response = await client.post(

View File

@@ -12,8 +12,8 @@ These tests prevent real-world attack scenarios.
import pytest
from httpx import AsyncClient
from app.crud.session import session as session_crud
from app.models.user import User
from app.repositories.session import session_repo as session_crud
class TestRevokedSessionSecurity:

View File

@@ -8,7 +8,7 @@ from uuid import uuid4
import pytest
from app.crud.oauth import oauth_account
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.schemas.oauth import OAuthAccountCreate
@@ -349,7 +349,7 @@ class TestOAuthProviderEndpoints:
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a test client
from app.crud.oauth import oauth_client
from app.repositories.oauth_client import oauth_client_repo as oauth_client
from app.schemas.oauth import OAuthClientCreate
async with AsyncTestingSessionLocal() as session:
@@ -386,7 +386,7 @@ class TestOAuthProviderEndpoints:
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create a test client
from app.crud.oauth import oauth_client
from app.repositories.oauth_client import oauth_client_repo as oauth_client
from app.schemas.oauth import OAuthClientCreate
async with AsyncTestingSessionLocal() as session:

View File

@@ -537,7 +537,7 @@ class TestOrganizationExceptionHandlers:
):
"""Test generic exception handler in get_my_organizations (covers lines 81-83)."""
with patch(
"app.crud.organization.organization.get_user_organizations_with_details",
"app.api.routes.organizations.organization_service.get_user_organizations_with_details",
side_effect=Exception("Database connection lost"),
):
# The exception handler logs and re-raises, so we expect the exception
@@ -554,7 +554,7 @@ class TestOrganizationExceptionHandlers:
):
"""Test generic exception handler in get_organization (covers lines 124-128)."""
with patch(
"app.crud.organization.organization.get",
"app.api.routes.organizations.organization_service.get_organization",
side_effect=Exception("Database timeout"),
):
with pytest.raises(Exception, match="Database timeout"):
@@ -569,7 +569,7 @@ class TestOrganizationExceptionHandlers:
):
"""Test generic exception handler in get_organization_members (covers lines 170-172)."""
with patch(
"app.crud.organization.organization.get_organization_members",
"app.api.routes.organizations.organization_service.get_organization_members",
side_effect=Exception("Connection pool exhausted"),
):
with pytest.raises(Exception, match="Connection pool exhausted"):
@@ -591,11 +591,11 @@ class TestOrganizationExceptionHandlers:
admin_token = login_response.json()["access_token"]
with patch(
"app.crud.organization.organization.get",
"app.api.routes.organizations.organization_service.get_organization",
return_value=test_org_with_user_admin,
):
with patch(
"app.crud.organization.organization.update",
"app.api.routes.organizations.organization_service.update_organization",
side_effect=Exception("Write lock timeout"),
):
with pytest.raises(Exception, match="Write lock timeout"):

View File

@@ -11,9 +11,9 @@ These tests prevent unauthorized access and privilege escalation.
import pytest
from httpx import AsyncClient
from app.crud.user import user as user_crud
from app.models.organization import Organization
from app.models.user import User
from app.repositories.user import user_repo as user_crud
class TestInactiveUserBlocking:

View File

@@ -39,7 +39,7 @@ async def async_test_user2(async_test_db):
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
from app.crud.user import user as user_crud
from app.repositories.user import user_repo as user_crud
from app.schemas.users import UserCreate
user_data = UserCreate(
@@ -191,7 +191,7 @@ class TestRevokeSession:
# Verify session is deactivated
async with SessionLocal() as session:
from app.crud.session import session as session_crud
from app.repositories.session import session_repo as session_crud
revoked_session = await session_crud.get(session, id=str(session_id))
assert revoked_session.is_active is False
@@ -268,7 +268,7 @@ class TestCleanupExpiredSessions:
_test_engine, SessionLocal = async_test_db
# Create expired and active sessions using CRUD to avoid greenlet issues
from app.crud.session import session as session_crud
from app.repositories.session import session_repo as session_crud
from app.schemas.sessions import SessionCreate
async with SessionLocal() as db:
@@ -334,7 +334,7 @@ class TestCleanupExpiredSessions:
_test_engine, SessionLocal = async_test_db
# Create only active sessions using CRUD
from app.crud.session import session as session_crud
from app.repositories.session import session_repo as session_crud
from app.schemas.sessions import SessionCreate
async with SessionLocal() as db:
@@ -384,7 +384,7 @@ class TestSessionsAdditionalCases:
# Create multiple sessions
async with SessionLocal() as session:
from app.crud.session import session as session_crud
from app.repositories.session import session_repo as session_crud
from app.schemas.sessions import SessionCreate
for i in range(5):
@@ -431,7 +431,7 @@ class TestSessionsAdditionalCases:
"""Test cleanup with mix of active/inactive and expired/not-expired sessions."""
_test_engine, SessionLocal = async_test_db
from app.crud.session import session as session_crud
from app.repositories.session import session_repo as session_crud
from app.schemas.sessions import SessionCreate
async with SessionLocal() as db:
@@ -502,10 +502,10 @@ class TestSessionExceptionHandlers:
"""Test list_sessions handles database errors (covers lines 104-106)."""
from unittest.mock import patch
from app.crud import session as session_module
from app.repositories import session as session_module
with patch.object(
session_module.session,
session_module.session_repo,
"get_user_sessions",
side_effect=Exception("Database error"),
):
@@ -527,10 +527,10 @@ class TestSessionExceptionHandlers:
from unittest.mock import patch
from uuid import uuid4
from app.crud import session as session_module
from app.repositories import session as session_module
# First create a session to revoke
from app.crud.session import session as session_crud
from app.repositories.session import session_repo as session_crud
from app.schemas.sessions import SessionCreate
_test_engine, AsyncTestingSessionLocal = async_test_db
@@ -550,7 +550,7 @@ class TestSessionExceptionHandlers:
# Mock the deactivate method to raise an exception
with patch.object(
session_module.session,
session_module.session_repo,
"deactivate",
side_effect=Exception("Database connection lost"),
):
@@ -568,10 +568,10 @@ class TestSessionExceptionHandlers:
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
from unittest.mock import patch
from app.crud import session as session_module
from app.repositories import session as session_module
with patch.object(
session_module.session,
session_module.session_repo,
"cleanup_expired_for_user",
side_effect=Exception("Cleanup failed"),
):

View File

@@ -99,7 +99,8 @@ class TestUpdateCurrentUser:
from unittest.mock import patch
with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("DB error")
"app.api.routes.users.user_service.update_user",
side_effect=Exception("DB error"),
):
with pytest.raises(Exception):
await client.patch(
@@ -134,7 +135,7 @@ class TestUpdateCurrentUser:
from unittest.mock import patch
with patch(
"app.api.routes.users.user_crud.update",
"app.api.routes.users.user_service.update_user",
side_effect=ValueError("Invalid value"),
):
with pytest.raises(ValueError):
@@ -224,7 +225,8 @@ class TestUpdateUserById:
from unittest.mock import patch
with patch(
"app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid")
"app.api.routes.users.user_service.update_user",
side_effect=ValueError("Invalid"),
):
with pytest.raises(ValueError):
await client.patch(
@@ -241,7 +243,8 @@ class TestUpdateUserById:
from unittest.mock import patch
with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected")
"app.api.routes.users.user_service.update_user",
side_effect=Exception("Unexpected"),
):
with pytest.raises(Exception):
await client.patch(
@@ -354,7 +357,7 @@ class TestDeleteUserById:
from unittest.mock import patch
with patch(
"app.api.routes.users.user_crud.soft_delete",
"app.api.routes.users.user_service.soft_delete_user",
side_effect=ValueError("Cannot delete"),
):
with pytest.raises(ValueError):
@@ -371,7 +374,7 @@ class TestDeleteUserById:
from unittest.mock import patch
with patch(
"app.api.routes.users.user_crud.soft_delete",
"app.api.routes.users.user_service.soft_delete_user",
side_effect=Exception("Unexpected"),
):
with pytest.raises(Exception):

View File

@@ -46,7 +46,7 @@ async def login_user(client, email: str, password: str = "SecurePassword123!"):
async def create_superuser(e2e_db_session, email: str, password: str):
"""Create a superuser directly in the database."""
from app.crud.user import user as user_crud
from app.repositories.user import user_repo as user_crud
from app.schemas.users import UserCreate
user_in = UserCreate(

View File

@@ -46,7 +46,7 @@ async def register_and_login(client, email: str, password: str = "SecurePassword
async def create_superuser_and_login(client, db_session):
"""Helper to create a superuser directly in DB and login."""
from app.crud.user import user as user_crud
from app.repositories.user import user_repo as user_crud
from app.schemas.users import UserCreate
email = f"admin-{uuid4().hex[:8]}@example.com"

View File

@@ -11,7 +11,12 @@ import pytest
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.orm import joinedload
from app.crud.user import user as user_crud
from app.core.repository_exceptions import (
DuplicateEntryError,
IntegrityConstraintError,
InvalidInputError,
)
from app.repositories.user import user_repo as user_crud
from app.schemas.users import UserCreate, UserUpdate
@@ -81,7 +86,7 @@ class TestCRUDBaseGetMulti:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi(session, skip=-1)
@pytest.mark.asyncio
@@ -90,7 +95,7 @@ class TestCRUDBaseGetMulti:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi(session, limit=-1)
@pytest.mark.asyncio
@@ -99,7 +104,7 @@ class TestCRUDBaseGetMulti:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi(session, limit=1001)
@pytest.mark.asyncio
@@ -140,7 +145,7 @@ class TestCRUDBaseCreate:
last_name="Duplicate",
)
with pytest.raises(ValueError, match="already exists"):
with pytest.raises(DuplicateEntryError, match="already exists"):
await user_crud.create(session, obj_in=user_data)
@pytest.mark.asyncio
@@ -165,7 +170,9 @@ class TestCRUDBaseCreate:
last_name="User",
)
with pytest.raises(ValueError, match="Database integrity error"):
with pytest.raises(
DuplicateEntryError, match="Database integrity error"
):
await user_crud.create(session, obj_in=user_data)
@pytest.mark.asyncio
@@ -244,7 +251,7 @@ class TestCRUDBaseUpdate:
# Create another user
async with SessionLocal() as session:
from app.crud.user import user as user_crud
from app.repositories.user import user_repo as user_crud
user2_data = UserCreate(
email="user2@example.com",
@@ -268,7 +275,7 @@ class TestCRUDBaseUpdate:
):
update_data = UserUpdate(email=async_test_user.email)
with pytest.raises(ValueError, match="already exists"):
with pytest.raises(DuplicateEntryError, match="already exists"):
await user_crud.update(
session, db_obj=user2_obj, obj_in=update_data
)
@@ -302,7 +309,9 @@ class TestCRUDBaseUpdate:
"statement", {}, Exception("constraint failed")
),
):
with pytest.raises(ValueError, match="Database integrity error"):
with pytest.raises(
IntegrityConstraintError, match="Database integrity error"
):
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@@ -322,7 +331,9 @@ class TestCRUDBaseUpdate:
"statement", {}, Exception("connection error")
),
):
with pytest.raises(ValueError, match="Database operation failed"):
with pytest.raises(
IntegrityConstraintError, match="Database operation failed"
):
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Test"}
)
@@ -403,7 +414,8 @@ class TestCRUDBaseRemove:
),
):
with pytest.raises(
ValueError, match="Cannot delete.*referenced by other records"
IntegrityConstraintError,
match="Cannot delete.*referenced by other records",
):
await user_crud.remove(session, id=str(async_test_user.id))
@@ -442,7 +454,7 @@ class TestCRUDBaseGetMultiWithTotal:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi_with_total(session, skip=-1)
@pytest.mark.asyncio
@@ -451,7 +463,7 @@ class TestCRUDBaseGetMultiWithTotal:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi_with_total(session, limit=-1)
@pytest.mark.asyncio
@@ -460,7 +472,7 @@ class TestCRUDBaseGetMultiWithTotal:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, limit=1001)
@pytest.mark.asyncio
@@ -827,7 +839,7 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"):
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
@pytest.mark.asyncio
@@ -836,7 +848,7 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"):
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
@pytest.mark.asyncio
@@ -845,7 +857,7 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db
async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"):
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
@pytest.mark.asyncio
@@ -899,8 +911,8 @@ class TestCRUDBaseModelsWithoutSoftDelete:
_test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at)
from app.crud.organization import organization as org_crud
from app.models.organization import Organization
from app.repositories.organization import organization_repo as org_crud
async with SessionLocal() as session:
org = Organization(name="Test Org", slug="test-org")
@@ -910,7 +922,9 @@ class TestCRUDBaseModelsWithoutSoftDelete:
# Try to soft delete organization (should fail)
async with SessionLocal() as session:
with pytest.raises(ValueError, match="does not have a deleted_at column"):
with pytest.raises(
InvalidInputError, match="does not have a deleted_at column"
):
await org_crud.soft_delete(session, id=str(org_id))
@pytest.mark.asyncio
@@ -919,8 +933,8 @@ class TestCRUDBaseModelsWithoutSoftDelete:
_test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at)
from app.crud.organization import organization as org_crud
from app.models.organization import Organization
from app.repositories.organization import organization_repo as org_crud
async with SessionLocal() as session:
org = Organization(name="Restore Test", slug="restore-test")
@@ -930,7 +944,9 @@ class TestCRUDBaseModelsWithoutSoftDelete:
# Try to restore organization (should fail)
async with SessionLocal() as session:
with pytest.raises(ValueError, match="does not have a deleted_at column"):
with pytest.raises(
InvalidInputError, match="does not have a deleted_at column"
):
await org_crud.restore(session, id=str(org_id))
@@ -950,8 +966,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
_test_engine, SessionLocal = async_test_db
# Create a session for the user
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud
async with SessionLocal() as session:
user_session = UserSession(
@@ -989,8 +1005,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
_test_engine, SessionLocal = async_test_db
# Create multiple sessions for the user
from app.crud.session import session as session_crud
from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud
async with SessionLocal() as session:
for i in range(3):

View File

@@ -10,7 +10,8 @@ from uuid import uuid4
import pytest
from sqlalchemy.exc import DataError, OperationalError
from app.crud.user import user as user_crud
from app.core.repository_exceptions import IntegrityConstraintError
from app.repositories.user import user_repo as user_crud
from app.schemas.users import UserCreate
@@ -119,7 +120,9 @@ class TestBaseCRUDUpdateFailures:
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"):
with pytest.raises(
IntegrityConstraintError, match="Database operation failed"
):
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)
@@ -141,7 +144,9 @@ class TestBaseCRUDUpdateFailures:
with patch.object(
session, "rollback", new_callable=AsyncMock
) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"):
with pytest.raises(
IntegrityConstraintError, match="Database operation failed"
):
await user_crud.update(
session, db_obj=user, obj_in={"first_name": "Updated"}
)

View File

@@ -7,7 +7,10 @@ from datetime import UTC, datetime, timedelta
import pytest
from app.crud.oauth import oauth_account, oauth_client, oauth_state
from app.core.repository_exceptions import DuplicateEntryError
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_client import oauth_client_repo as oauth_client
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
@@ -60,7 +63,8 @@ class TestOAuthAccountCRUD:
# SQLite returns different error message than PostgreSQL
with pytest.raises(
ValueError, match="(already linked|UNIQUE constraint failed)"
DuplicateEntryError,
match="(already linked|UNIQUE constraint failed|Failed to create)",
):
await oauth_account.create_account(session, obj_in=account_data2)
@@ -256,13 +260,13 @@ class TestOAuthAccountCRUD:
updated = await oauth_account.update_tokens(
session,
account=account,
access_token_encrypted="new_access_token",
refresh_token_encrypted="new_refresh_token",
access_token="new_access_token",
refresh_token="new_refresh_token",
token_expires_at=new_expires,
)
assert updated.access_token_encrypted == "new_access_token"
assert updated.refresh_token_encrypted == "new_refresh_token"
assert updated.access_token == "new_access_token"
assert updated.refresh_token == "new_refresh_token"
class TestOAuthStateCRUD:

View File

@@ -9,9 +9,10 @@ from uuid import uuid4
import pytest
from sqlalchemy import select
from app.crud.organization import organization as organization_crud
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
from app.models.organization import Organization
from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.organization import organization_repo as organization_crud
from app.schemas.organizations import OrganizationCreate
@@ -87,7 +88,7 @@ class TestCreate:
# Try to create second with same slug
async with AsyncTestingSessionLocal() as session:
org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug")
with pytest.raises(ValueError, match="already exists"):
with pytest.raises(DuplicateEntryError, match="already exists"):
await organization_crud.create(session, obj_in=org_in)
@pytest.mark.asyncio
@@ -295,7 +296,7 @@ class TestAddUser:
org_id = org.id
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError, match="already a member"):
with pytest.raises(DuplicateEntryError, match="already a member"):
await organization_crud.add_user(
session, organization_id=org_id, user_id=async_test_user.id
)
@@ -972,7 +973,9 @@ class TestOrganizationExceptionHandlers:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, "rollback", new_callable=AsyncMock):
org_in = OrganizationCreate(name="Test", slug="test")
with pytest.raises(ValueError, match="Database integrity error"):
with pytest.raises(
IntegrityConstraintError, match="Database integrity error"
):
await organization_crud.create(session, obj_in=org_in)
@pytest.mark.asyncio
@@ -1058,7 +1061,8 @@ class TestOrganizationExceptionHandlers:
with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(
ValueError, match="Failed to add user to organization"
IntegrityConstraintError,
match="Failed to add user to organization",
):
await organization_crud.add_user(
session,

View File

@@ -8,8 +8,9 @@ from uuid import uuid4
import pytest
from app.crud.session import session as session_crud
from app.core.repository_exceptions import InvalidInputError
from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud
from app.schemas.sessions import SessionCreate
@@ -503,7 +504,7 @@ class TestCleanupExpiredForUser:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError, match="Invalid user ID format"):
with pytest.raises(InvalidInputError, match="Invalid user ID format"):
await session_crud.cleanup_expired_for_user(
session, user_id="not-a-valid-uuid"
)

View File

@@ -10,8 +10,9 @@ from uuid import uuid4
import pytest
from sqlalchemy.exc import OperationalError
from app.crud.session import session as session_crud
from app.core.repository_exceptions import IntegrityConstraintError
from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_crud
from app.schemas.sessions import SessionCreate
@@ -102,7 +103,9 @@ class TestSessionCRUDCreateSessionFailures:
last_used_at=datetime.now(UTC),
)
with pytest.raises(ValueError, match="Failed to create session"):
with pytest.raises(
IntegrityConstraintError, match="Failed to create session"
):
await session_crud.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once()
@@ -133,7 +136,9 @@ class TestSessionCRUDCreateSessionFailures:
last_used_at=datetime.now(UTC),
)
with pytest.raises(ValueError, match="Failed to create session"):
with pytest.raises(
IntegrityConstraintError, match="Failed to create session"
):
await session_crud.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once()

View File

@@ -5,7 +5,8 @@ Comprehensive tests for async user CRUD operations.
import pytest
from app.crud.user import user as user_crud
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
from app.repositories.user import user_repo as user_crud
from app.schemas.users import UserCreate, UserUpdate
@@ -93,7 +94,7 @@ class TestCreate:
last_name="User",
)
with pytest.raises(ValueError) as exc_info:
with pytest.raises(DuplicateEntryError) as exc_info:
await user_crud.create(session, obj_in=user_data)
assert "already exists" in str(exc_info.value).lower()
@@ -330,7 +331,7 @@ class TestGetMultiWithTotal:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
assert "skip must be non-negative" in str(exc_info.value)
@@ -341,7 +342,7 @@ class TestGetMultiWithTotal:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
assert "limit must be non-negative" in str(exc_info.value)
@@ -352,7 +353,7 @@ class TestGetMultiWithTotal:
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info:
with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
assert "Maximum limit is 1000" in str(exc_info.value)

View File

@@ -10,6 +10,7 @@ from app.core.auth import (
get_password_hash,
verify_password,
)
from app.core.exceptions import DuplicateError
from app.models.user import User
from app.schemas.users import Token, UserCreate
from app.services.auth_service import AuthenticationError, AuthService
@@ -152,9 +153,9 @@ class TestAuthServiceUserCreation:
last_name="User",
)
# Should raise AuthenticationError
# Should raise DuplicateError for duplicate email
async with AsyncTestingSessionLocal() as session:
with pytest.raises(AuthenticationError):
with pytest.raises(DuplicateError):
await AuthService.create_user(db=session, user_data=user_data)

View File

@@ -269,18 +269,18 @@ class TestClientValidation:
async def test_validate_client_legacy_sha256_hash(
self, db, confidential_client_legacy_hash
):
"""Test validating a client with legacy SHA-256 hash (backward compatibility)."""
"""Test that legacy SHA-256 hash is rejected with clear error message."""
client, secret = confidential_client_legacy_hash
validated = await service.validate_client(db, client.client_id, secret)
assert validated.client_id == client.client_id
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
await service.validate_client(db, client.client_id, secret)
@pytest.mark.asyncio
async def test_validate_client_legacy_sha256_wrong_secret(
self, db, confidential_client_legacy_hash
):
"""Test legacy SHA-256 client rejects wrong secret."""
"""Test that legacy SHA-256 client with wrong secret is rejected."""
client, _ = confidential_client_legacy_hash
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
await service.validate_client(db, client.client_id, "wrong_secret")
def test_validate_redirect_uri_success(self, public_client):

View File

@@ -11,7 +11,8 @@ from uuid import uuid4
import pytest
from app.core.exceptions import AuthenticationError
from app.crud.oauth import oauth_account, oauth_state
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService

View File

@@ -0,0 +1,444 @@
# tests/services/test_organization_service.py
"""Tests for the OrganizationService class."""
import uuid
import pytest
from app.core.exceptions import NotFoundError
from app.models.user_organization import OrganizationRole
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
from app.services.organization_service import organization_service
def _make_org_create(name=None, slug=None) -> OrganizationCreate:
"""Helper to create an OrganizationCreate schema with unique defaults."""
unique = uuid.uuid4().hex[:8]
return OrganizationCreate(
name=name or f"Test Org {unique}",
slug=slug or f"test-org-{unique}",
description="A test organization",
is_active=True,
settings={},
)
class TestGetOrganization:
"""Tests for OrganizationService.get_organization method."""
@pytest.mark.asyncio
async def test_get_organization_found(self, async_test_db, async_test_user):
"""Test getting an existing organization by ID returns the org."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
result = await organization_service.get_organization(
session, str(created.id)
)
assert result is not None
assert result.id == created.id
assert result.slug == created.slug
@pytest.mark.asyncio
async def test_get_organization_not_found(self, async_test_db):
"""Test getting a non-existent organization raises NotFoundError."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
with pytest.raises(NotFoundError):
await organization_service.get_organization(session, str(uuid.uuid4()))
class TestCreateOrganization:
"""Tests for OrganizationService.create_organization method."""
@pytest.mark.asyncio
async def test_create_organization(self, async_test_db, async_test_user):
"""Test creating a new organization returns the created org with correct fields."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_org_create()
async with AsyncTestingSessionLocal() as session:
result = await organization_service.create_organization(
session, obj_in=obj_in
)
assert result is not None
assert result.name == obj_in.name
assert result.slug == obj_in.slug
assert result.description == obj_in.description
assert result.is_active is True
class TestUpdateOrganization:
"""Tests for OrganizationService.update_organization method."""
@pytest.mark.asyncio
async def test_update_organization(self, async_test_db, async_test_user):
"""Test updating an organization name."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
org = await organization_service.get_organization(session, str(created.id))
updated = await organization_service.update_organization(
session,
org=org,
obj_in=OrganizationUpdate(name="Updated Org Name"),
)
assert updated.name == "Updated Org Name"
assert updated.id == created.id
@pytest.mark.asyncio
async def test_update_organization_with_dict(self, async_test_db, async_test_user):
"""Test updating an organization using a dict."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
org = await organization_service.get_organization(session, str(created.id))
updated = await organization_service.update_organization(
session,
org=org,
obj_in={"description": "Updated description"},
)
assert updated.description == "Updated description"
class TestRemoveOrganization:
"""Tests for OrganizationService.remove_organization method."""
@pytest.mark.asyncio
async def test_remove_organization(self, async_test_db, async_test_user):
"""Test permanently deleting an organization."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
org_id = str(created.id)
async with AsyncTestingSessionLocal() as session:
await organization_service.remove_organization(session, org_id)
async with AsyncTestingSessionLocal() as session:
with pytest.raises(NotFoundError):
await organization_service.get_organization(session, org_id)
class TestGetMemberCount:
"""Tests for OrganizationService.get_member_count method."""
@pytest.mark.asyncio
async def test_get_member_count_empty(self, async_test_db, async_test_user):
"""Test member count for org with no members is zero."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
count = await organization_service.get_member_count(
session, organization_id=created.id
)
assert count == 0
@pytest.mark.asyncio
async def test_get_member_count_with_member(self, async_test_db, async_test_user):
"""Test member count increases after adding a member."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
count = await organization_service.get_member_count(
session, organization_id=created.id
)
assert count == 1
class TestGetMultiWithMemberCounts:
"""Tests for OrganizationService.get_multi_with_member_counts method."""
@pytest.mark.asyncio
async def test_get_multi_with_member_counts(self, async_test_db, async_test_user):
"""Test listing organizations with member counts returns tuple."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
orgs, count = await organization_service.get_multi_with_member_counts(
session, skip=0, limit=10
)
assert isinstance(orgs, list)
assert isinstance(count, int)
assert count >= 1
@pytest.mark.asyncio
async def test_get_multi_with_member_counts_search(
self, async_test_db, async_test_user
):
"""Test listing organizations with a search filter."""
_test_engine, AsyncTestingSessionLocal = async_test_db
unique = uuid.uuid4().hex[:8]
org_name = f"Searchable Org {unique}"
async with AsyncTestingSessionLocal() as session:
await organization_service.create_organization(
session,
obj_in=OrganizationCreate(
name=org_name,
slug=f"searchable-org-{unique}",
is_active=True,
settings={},
),
)
async with AsyncTestingSessionLocal() as session:
orgs, count = await organization_service.get_multi_with_member_counts(
session, skip=0, limit=10, search=f"Searchable Org {unique}"
)
assert count >= 1
# Each element is a dict with key "organization" (an Organization obj) and "member_count"
names = [o["organization"].name for o in orgs]
assert org_name in names
class TestGetUserOrganizationsWithDetails:
"""Tests for OrganizationService.get_user_organizations_with_details method."""
@pytest.mark.asyncio
async def test_get_user_organizations_with_details(
self, async_test_db, async_test_user
):
"""Test getting organizations for a user returns list of dicts."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
orgs = await organization_service.get_user_organizations_with_details(
session, user_id=async_test_user.id
)
assert isinstance(orgs, list)
assert len(orgs) >= 1
class TestGetOrganizationMembers:
"""Tests for OrganizationService.get_organization_members method."""
@pytest.mark.asyncio
async def test_get_organization_members(self, async_test_db, async_test_user):
"""Test getting organization members returns paginated results."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
members, count = await organization_service.get_organization_members(
session, organization_id=created.id, skip=0, limit=10
)
assert isinstance(members, list)
assert isinstance(count, int)
assert count >= 1
class TestAddMember:
"""Tests for OrganizationService.add_member method."""
@pytest.mark.asyncio
async def test_add_member_default_role(self, async_test_db, async_test_user):
"""Test adding a user to an org with default MEMBER role."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
membership = await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
assert membership is not None
assert membership.user_id == async_test_user.id
assert membership.organization_id == created.id
assert membership.role == OrganizationRole.MEMBER
@pytest.mark.asyncio
async def test_add_member_admin_role(self, async_test_db, async_test_user):
"""Test adding a user to an org with ADMIN role."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
membership = await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
role=OrganizationRole.ADMIN,
)
assert membership.role == OrganizationRole.ADMIN
class TestRemoveMember:
"""Tests for OrganizationService.remove_member method."""
@pytest.mark.asyncio
async def test_remove_member(self, async_test_db, async_test_user):
"""Test removing a member from an org returns True."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
removed = await organization_service.remove_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
assert removed is True
@pytest.mark.asyncio
async def test_remove_member_not_found(self, async_test_db, async_test_user):
"""Test removing a non-member returns False."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
removed = await organization_service.remove_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
assert removed is False
class TestGetUserRoleInOrg:
"""Tests for OrganizationService.get_user_role_in_org method."""
@pytest.mark.asyncio
async def test_get_user_role_in_org(self, async_test_db, async_test_user):
"""Test getting a user's role in an org they belong to."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
role=OrganizationRole.MEMBER,
)
async with AsyncTestingSessionLocal() as session:
role = await organization_service.get_user_role_in_org(
session,
user_id=async_test_user.id,
organization_id=created.id,
)
assert role == OrganizationRole.MEMBER
@pytest.mark.asyncio
async def test_get_user_role_in_org_not_member(
self, async_test_db, async_test_user
):
"""Test getting role for a user not in the org returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
async with AsyncTestingSessionLocal() as session:
role = await organization_service.get_user_role_in_org(
session,
user_id=async_test_user.id,
organization_id=created.id,
)
assert role is None
class TestGetOrgDistribution:
"""Tests for OrganizationService.get_org_distribution method."""
@pytest.mark.asyncio
async def test_get_org_distribution_empty(self, async_test_db):
"""Test org distribution with no memberships returns empty list."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await organization_service.get_org_distribution(session, limit=6)
assert isinstance(result, list)
@pytest.mark.asyncio
async def test_get_org_distribution_with_members(
self, async_test_db, async_test_user
):
"""Test org distribution returns org name and member count."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
created = await organization_service.create_organization(
session, obj_in=_make_org_create()
)
await organization_service.add_member(
session,
organization_id=created.id,
user_id=async_test_user.id,
)
async with AsyncTestingSessionLocal() as session:
result = await organization_service.get_org_distribution(session, limit=6)
assert isinstance(result, list)
assert len(result) >= 1
entry = result[0]
assert "name" in entry
assert "value" in entry
assert entry["value"] >= 1

View File

@@ -0,0 +1,291 @@
# tests/services/test_session_service.py
"""Tests for the SessionService class."""
import uuid
from datetime import UTC, datetime, timedelta
import pytest
from app.schemas.sessions import SessionCreate
from app.services.session_service import session_service
def _make_session_create(user_id, jti=None) -> SessionCreate:
"""Helper to build a SessionCreate with sensible defaults."""
now = datetime.now(UTC)
return SessionCreate(
user_id=user_id,
refresh_token_jti=jti or str(uuid.uuid4()),
ip_address="127.0.0.1",
user_agent="pytest/test",
device_name="Test Device",
device_id="test-device-id",
last_used_at=now,
expires_at=now + timedelta(days=7),
location_city="TestCity",
location_country="TestCountry",
)
class TestCreateSession:
"""Tests for SessionService.create_session method."""
@pytest.mark.asyncio
async def test_create_session(self, async_test_db, async_test_user):
"""Test creating a session returns a UserSession with correct fields."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
result = await session_service.create_session(session, obj_in=obj_in)
assert result is not None
assert result.user_id == async_test_user.id
assert result.refresh_token_jti == obj_in.refresh_token_jti
assert result.is_active is True
assert result.ip_address == "127.0.0.1"
class TestGetSession:
"""Tests for SessionService.get_session method."""
@pytest.mark.asyncio
async def test_get_session_found(self, async_test_db, async_test_user):
"""Test getting a session by ID returns the session."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_session(session, str(created.id))
assert result is not None
assert result.id == created.id
@pytest.mark.asyncio
async def test_get_session_not_found(self, async_test_db):
"""Test getting a non-existent session returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_session(session, str(uuid.uuid4()))
assert result is None
class TestGetUserSessions:
"""Tests for SessionService.get_user_sessions method."""
@pytest.mark.asyncio
async def test_get_user_sessions_active_only(self, async_test_db, async_test_user):
"""Test getting active sessions for a user returns only active sessions."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
sessions = await session_service.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=True
)
assert isinstance(sessions, list)
assert len(sessions) >= 1
for s in sessions:
assert s.is_active is True
@pytest.mark.asyncio
async def test_get_user_sessions_all(self, async_test_db, async_test_user):
"""Test getting all sessions (active and inactive) for a user."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
await session_service.deactivate(session, session_id=str(created.id))
async with AsyncTestingSessionLocal() as session:
sessions = await session_service.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=False
)
assert isinstance(sessions, list)
assert len(sessions) >= 1
class TestGetActiveByJti:
"""Tests for SessionService.get_active_by_jti method."""
@pytest.mark.asyncio
async def test_get_active_by_jti_found(self, async_test_db, async_test_user):
"""Test getting an active session by JTI returns the session."""
_test_engine, AsyncTestingSessionLocal = async_test_db
jti = str(uuid.uuid4())
obj_in = _make_session_create(async_test_user.id, jti=jti)
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_active_by_jti(session, jti=jti)
assert result is not None
assert result.refresh_token_jti == jti
assert result.is_active is True
@pytest.mark.asyncio
async def test_get_active_by_jti_not_found(self, async_test_db):
"""Test getting an active session by non-existent JTI returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_active_by_jti(
session, jti=str(uuid.uuid4())
)
assert result is None
class TestGetByJti:
"""Tests for SessionService.get_by_jti method."""
@pytest.mark.asyncio
async def test_get_by_jti_active(self, async_test_db, async_test_user):
"""Test getting a session (active or inactive) by JTI."""
_test_engine, AsyncTestingSessionLocal = async_test_db
jti = str(uuid.uuid4())
obj_in = _make_session_create(async_test_user.id, jti=jti)
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_by_jti(session, jti=jti)
assert result is not None
assert result.refresh_token_jti == jti
class TestDeactivate:
"""Tests for SessionService.deactivate method."""
@pytest.mark.asyncio
async def test_deactivate_session(self, async_test_db, async_test_user):
"""Test deactivating a session sets is_active to False."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
session_id = str(created.id)
async with AsyncTestingSessionLocal() as session:
deactivated = await session_service.deactivate(
session, session_id=session_id
)
assert deactivated is not None
assert deactivated.is_active is False
class TestDeactivateAllUserSessions:
"""Tests for SessionService.deactivate_all_user_sessions method."""
@pytest.mark.asyncio
async def test_deactivate_all_user_sessions(self, async_test_db, async_test_user):
"""Test deactivating all sessions for a user returns count deactivated."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(
session, obj_in=_make_session_create(async_test_user.id)
)
await session_service.create_session(
session, obj_in=_make_session_create(async_test_user.id)
)
async with AsyncTestingSessionLocal() as session:
count = await session_service.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id)
)
assert count >= 2
async with AsyncTestingSessionLocal() as session:
active_sessions = await session_service.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=True
)
assert len(active_sessions) == 0
class TestUpdateRefreshToken:
"""Tests for SessionService.update_refresh_token method."""
@pytest.mark.asyncio
async def test_update_refresh_token(self, async_test_db, async_test_user):
"""Test rotating a session's refresh token updates JTI and expiry."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
session_id = str(created.id)
new_jti = str(uuid.uuid4())
new_expires_at = datetime.now(UTC) + timedelta(days=14)
async with AsyncTestingSessionLocal() as session:
result = await session_service.get_session(session, session_id)
updated = await session_service.update_refresh_token(
session,
session=result,
new_jti=new_jti,
new_expires_at=new_expires_at,
)
assert updated.refresh_token_jti == new_jti
class TestCleanupExpiredForUser:
"""Tests for SessionService.cleanup_expired_for_user method."""
@pytest.mark.asyncio
async def test_cleanup_expired_for_user(self, async_test_db, async_test_user):
"""Test cleaning up expired inactive sessions returns count removed."""
_test_engine, AsyncTestingSessionLocal = async_test_db
now = datetime.now(UTC)
# Create a session that is already expired
obj_in = SessionCreate(
user_id=async_test_user.id,
refresh_token_jti=str(uuid.uuid4()),
ip_address="127.0.0.1",
user_agent="pytest/test",
last_used_at=now - timedelta(days=8),
expires_at=now - timedelta(days=1),
)
async with AsyncTestingSessionLocal() as session:
created = await session_service.create_session(session, obj_in=obj_in)
session_id = str(created.id)
# Deactivate it so it qualifies for cleanup (requires is_active=False AND expired)
async with AsyncTestingSessionLocal() as session:
await session_service.deactivate(session, session_id=session_id)
async with AsyncTestingSessionLocal() as session:
count = await session_service.cleanup_expired_for_user(
session, user_id=str(async_test_user.id)
)
assert isinstance(count, int)
assert count >= 1
class TestGetAllSessions:
"""Tests for SessionService.get_all_sessions method."""
@pytest.mark.asyncio
async def test_get_all_sessions(self, async_test_db, async_test_user):
"""Test getting all sessions with pagination returns tuple of list and count."""
_test_engine, AsyncTestingSessionLocal = async_test_db
obj_in = _make_session_create(async_test_user.id)
async with AsyncTestingSessionLocal() as session:
await session_service.create_session(session, obj_in=obj_in)
async with AsyncTestingSessionLocal() as session:
sessions, count = await session_service.get_all_sessions(
session, skip=0, limit=10, active_only=True, with_user=False
)
assert isinstance(sessions, list)
assert isinstance(count, int)
assert count >= 1
assert len(sessions) >= 1

View File

@@ -0,0 +1,213 @@
# tests/services/test_user_service.py
"""Tests for the UserService class."""
import uuid
import pytest
from sqlalchemy import select
from app.core.exceptions import NotFoundError
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
from app.services.user_service import user_service
class TestGetUser:
"""Tests for UserService.get_user method."""
@pytest.mark.asyncio
async def test_get_user_found(self, async_test_db, async_test_user):
"""Test getting an existing user by ID returns the user."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_service.get_user(session, str(async_test_user.id))
assert result is not None
assert result.id == async_test_user.id
assert result.email == async_test_user.email
@pytest.mark.asyncio
async def test_get_user_not_found(self, async_test_db):
"""Test getting a non-existent user raises NotFoundError."""
_test_engine, AsyncTestingSessionLocal = async_test_db
non_existent_id = str(uuid.uuid4())
async with AsyncTestingSessionLocal() as session:
with pytest.raises(NotFoundError):
await user_service.get_user(session, non_existent_id)
class TestGetByEmail:
"""Tests for UserService.get_by_email method."""
@pytest.mark.asyncio
async def test_get_by_email_found(self, async_test_db, async_test_user):
"""Test getting an existing user by email returns the user."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_service.get_by_email(session, async_test_user.email)
assert result is not None
assert result.id == async_test_user.id
assert result.email == async_test_user.email
@pytest.mark.asyncio
async def test_get_by_email_not_found(self, async_test_db):
"""Test getting a user by non-existent email returns None."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
result = await user_service.get_by_email(session, "nonexistent@example.com")
assert result is None
class TestCreateUser:
"""Tests for UserService.create_user method."""
@pytest.mark.asyncio
async def test_create_user(self, async_test_db):
"""Test creating a new user with valid data."""
_test_engine, AsyncTestingSessionLocal = async_test_db
unique_email = f"test_{uuid.uuid4()}@example.com"
user_data = UserCreate(
email=unique_email,
password="TestPassword123!",
first_name="New",
last_name="User",
)
async with AsyncTestingSessionLocal() as session:
result = await user_service.create_user(session, user_data)
assert result is not None
assert result.email == unique_email
assert result.first_name == "New"
assert result.last_name == "User"
assert result.is_active is True
class TestUpdateUser:
"""Tests for UserService.update_user method."""
@pytest.mark.asyncio
async def test_update_user(self, async_test_db, async_test_user):
"""Test updating a user's first_name."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = await user_service.get_user(session, str(async_test_user.id))
updated = await user_service.update_user(
session,
user=user,
obj_in=UserUpdate(first_name="Updated"),
)
assert updated.first_name == "Updated"
assert updated.id == async_test_user.id
class TestSoftDeleteUser:
"""Tests for UserService.soft_delete_user method."""
@pytest.mark.asyncio
async def test_soft_delete_user(self, async_test_db, async_test_user):
"""Test soft-deleting a user sets deleted_at."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
await user_service.soft_delete_user(session, str(async_test_user.id))
async with AsyncTestingSessionLocal() as session:
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
assert user is not None
assert user.deleted_at is not None
class TestListUsers:
"""Tests for UserService.list_users method."""
@pytest.mark.asyncio
async def test_list_users(self, async_test_db, async_test_user):
"""Test listing users with pagination returns correct results."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
users, count = await user_service.list_users(session, skip=0, limit=10)
assert isinstance(users, list)
assert isinstance(count, int)
assert count >= 1
assert len(users) >= 1
@pytest.mark.asyncio
async def test_list_users_with_search(self, async_test_db, async_test_user):
"""Test listing users with email fragment search returns matching users."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Search by partial email fragment of the test user
email_fragment = async_test_user.email.split("@")[0]
async with AsyncTestingSessionLocal() as session:
users, count = await user_service.list_users(
session, skip=0, limit=10, search=email_fragment
)
assert isinstance(users, list)
assert count >= 1
emails = [u.email for u in users]
assert async_test_user.email in emails
class TestBulkUpdateStatus:
"""Tests for UserService.bulk_update_status method."""
@pytest.mark.asyncio
async def test_bulk_update_status(self, async_test_db, async_test_user):
"""Test bulk activating users returns correct count."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_service.bulk_update_status(
session,
user_ids=[async_test_user.id],
is_active=True,
)
assert count >= 1
async with AsyncTestingSessionLocal() as session:
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
assert user is not None
assert user.is_active is True
class TestBulkSoftDelete:
"""Tests for UserService.bulk_soft_delete method."""
@pytest.mark.asyncio
async def test_bulk_soft_delete(self, async_test_db, async_test_user):
"""Test bulk soft-deleting users returns correct count."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
count = await user_service.bulk_soft_delete(
session,
user_ids=[async_test_user.id],
)
assert count >= 1
async with AsyncTestingSessionLocal() as session:
result = await session.execute(
select(User).where(User.id == async_test_user.id)
)
user = result.scalar_one_or_none()
assert user is not None
assert user.deleted_at is not None
class TestGetStats:
"""Tests for UserService.get_stats method."""
@pytest.mark.asyncio
async def test_get_stats(self, async_test_db, async_test_user):
"""Test get_stats returns dict with expected keys and correct counts."""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
stats = await user_service.get_stats(session)
assert "total_users" in stats
assert "active_count" in stats
assert "inactive_count" in stats
assert "all_users" in stats
assert stats["total_users"] >= 1
assert stats["active_count"] >= 1
assert isinstance(stats["all_users"], list)
assert len(stats["all_users"]) >= 1

66
backend/uv.lock generated
View File

@@ -519,7 +519,7 @@ dependencies = [
[package.optional-dependencies]
dev = [
{ name = "freezegun" },
{ name = "mypy" },
{ name = "pyright" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
@@ -546,12 +546,12 @@ requires-dist = [
{ name = "fastapi-utils", specifier = "==0.8.0" },
{ name = "freezegun", marker = "extra == 'dev'", specifier = "~=1.5.1" },
{ name = "httpx", specifier = ">=0.27.0" },
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
{ name = "passlib", specifier = "==1.7.4" },
{ name = "pillow", specifier = ">=10.3.0" },
{ name = "psycopg2-binary", specifier = ">=2.9.9" },
{ name = "pydantic", specifier = ">=2.10.6" },
{ name = "pydantic-settings", specifier = ">=2.2.1" },
{ name = "pyright", marker = "extra == 'dev'", specifier = ">=1.1.390" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.5" },
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0" },
@@ -966,44 +966,12 @@ wheels = [
]
[[package]]
name = "mypy"
version = "1.18.2"
name = "nodeenv"
version = "1.10.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mypy-extensions" },
{ name = "pathspec" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/77/8f0d0001ffad290cef2f7f216f96c814866248a0b92a722365ed54648e7e/mypy-1.18.2.tar.gz", hash = "sha256:06a398102a5f203d7477b2923dda3634c36727fa5c237d8f859ef90c42a9924b", size = 3448846, upload-time = "2025-09-19T00:11:10.519Z" }
sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/06/dfdd2bc60c66611dd8335f463818514733bc763e4760dee289dcc33df709/mypy-1.18.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:33eca32dd124b29400c31d7cf784e795b050ace0e1f91b8dc035672725617e34", size = 12908273, upload-time = "2025-09-19T00:10:58.321Z" },
{ url = "https://files.pythonhosted.org/packages/81/14/6a9de6d13a122d5608e1a04130724caf9170333ac5a924e10f670687d3eb/mypy-1.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a3c47adf30d65e89b2dcd2fa32f3aeb5e94ca970d2c15fcb25e297871c8e4764", size = 11920910, upload-time = "2025-09-19T00:10:20.043Z" },
{ url = "https://files.pythonhosted.org/packages/5f/a9/b29de53e42f18e8cc547e38daa9dfa132ffdc64f7250e353f5c8cdd44bee/mypy-1.18.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d6c838e831a062f5f29d11c9057c6009f60cb294fea33a98422688181fe2893", size = 12465585, upload-time = "2025-09-19T00:10:33.005Z" },
{ url = "https://files.pythonhosted.org/packages/77/ae/6c3d2c7c61ff21f2bee938c917616c92ebf852f015fb55917fd6e2811db2/mypy-1.18.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01199871b6110a2ce984bde85acd481232d17413868c9807e95c1b0739a58914", size = 13348562, upload-time = "2025-09-19T00:10:11.51Z" },
{ url = "https://files.pythonhosted.org/packages/4d/31/aec68ab3b4aebdf8f36d191b0685d99faa899ab990753ca0fee60fb99511/mypy-1.18.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a2afc0fa0b0e91b4599ddfe0f91e2c26c2b5a5ab263737e998d6817874c5f7c8", size = 13533296, upload-time = "2025-09-19T00:10:06.568Z" },
{ url = "https://files.pythonhosted.org/packages/9f/83/abcb3ad9478fca3ebeb6a5358bb0b22c95ea42b43b7789c7fb1297ca44f4/mypy-1.18.2-cp312-cp312-win_amd64.whl", hash = "sha256:d8068d0afe682c7c4897c0f7ce84ea77f6de953262b12d07038f4d296d547074", size = 9828828, upload-time = "2025-09-19T00:10:28.203Z" },
{ url = "https://files.pythonhosted.org/packages/5f/04/7f462e6fbba87a72bc8097b93f6842499c428a6ff0c81dd46948d175afe8/mypy-1.18.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:07b8b0f580ca6d289e69209ec9d3911b4a26e5abfde32228a288eb79df129fcc", size = 12898728, upload-time = "2025-09-19T00:10:01.33Z" },
{ url = "https://files.pythonhosted.org/packages/99/5b/61ed4efb64f1871b41fd0b82d29a64640f3516078f6c7905b68ab1ad8b13/mypy-1.18.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ed4482847168439651d3feee5833ccedbf6657e964572706a2adb1f7fa4dfe2e", size = 11910758, upload-time = "2025-09-19T00:10:42.607Z" },
{ url = "https://files.pythonhosted.org/packages/3c/46/d297d4b683cc89a6e4108c4250a6a6b717f5fa96e1a30a7944a6da44da35/mypy-1.18.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3ad2afadd1e9fea5cf99a45a822346971ede8685cc581ed9cd4d42eaf940986", size = 12475342, upload-time = "2025-09-19T00:11:00.371Z" },
{ url = "https://files.pythonhosted.org/packages/83/45/4798f4d00df13eae3bfdf726c9244bcb495ab5bd588c0eed93a2f2dd67f3/mypy-1.18.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a431a6f1ef14cf8c144c6b14793a23ec4eae3db28277c358136e79d7d062f62d", size = 13338709, upload-time = "2025-09-19T00:11:03.358Z" },
{ url = "https://files.pythonhosted.org/packages/d7/09/479f7358d9625172521a87a9271ddd2441e1dab16a09708f056e97007207/mypy-1.18.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7ab28cc197f1dd77a67e1c6f35cd1f8e8b73ed2217e4fc005f9e6a504e46e7ba", size = 13529806, upload-time = "2025-09-19T00:10:26.073Z" },
{ url = "https://files.pythonhosted.org/packages/71/cf/ac0f2c7e9d0ea3c75cd99dff7aec1c9df4a1376537cb90e4c882267ee7e9/mypy-1.18.2-cp313-cp313-win_amd64.whl", hash = "sha256:0e2785a84b34a72ba55fb5daf079a1003a34c05b22238da94fcae2bbe46f3544", size = 9833262, upload-time = "2025-09-19T00:10:40.035Z" },
{ url = "https://files.pythonhosted.org/packages/5a/0c/7d5300883da16f0063ae53996358758b2a2df2a09c72a5061fa79a1f5006/mypy-1.18.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:62f0e1e988ad41c2a110edde6c398383a889d95b36b3e60bcf155f5164c4fdce", size = 12893775, upload-time = "2025-09-19T00:10:03.814Z" },
{ url = "https://files.pythonhosted.org/packages/50/df/2cffbf25737bdb236f60c973edf62e3e7b4ee1c25b6878629e88e2cde967/mypy-1.18.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8795a039bab805ff0c1dfdb8cd3344642c2b99b8e439d057aba30850b8d3423d", size = 11936852, upload-time = "2025-09-19T00:10:51.631Z" },
{ url = "https://files.pythonhosted.org/packages/be/50/34059de13dd269227fb4a03be1faee6e2a4b04a2051c82ac0a0b5a773c9a/mypy-1.18.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ca1e64b24a700ab5ce10133f7ccd956a04715463d30498e64ea8715236f9c9c", size = 12480242, upload-time = "2025-09-19T00:11:07.955Z" },
{ url = "https://files.pythonhosted.org/packages/5b/11/040983fad5132d85914c874a2836252bbc57832065548885b5bb5b0d4359/mypy-1.18.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d924eef3795cc89fecf6bedc6ed32b33ac13e8321344f6ddbf8ee89f706c05cb", size = 13326683, upload-time = "2025-09-19T00:09:55.572Z" },
{ url = "https://files.pythonhosted.org/packages/e9/ba/89b2901dd77414dd7a8c8729985832a5735053be15b744c18e4586e506ef/mypy-1.18.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20c02215a080e3a2be3aa50506c67242df1c151eaba0dcbc1e4e557922a26075", size = 13514749, upload-time = "2025-09-19T00:10:44.827Z" },
{ url = "https://files.pythonhosted.org/packages/25/bc/cc98767cffd6b2928ba680f3e5bc969c4152bf7c2d83f92f5a504b92b0eb/mypy-1.18.2-cp314-cp314-win_amd64.whl", hash = "sha256:749b5f83198f1ca64345603118a6f01a4e99ad4bf9d103ddc5a3200cc4614adf", size = 9982959, upload-time = "2025-09-19T00:10:37.344Z" },
{ url = "https://files.pythonhosted.org/packages/87/e3/be76d87158ebafa0309946c4a73831974d4d6ab4f4ef40c3b53a385a66fd/mypy-1.18.2-py3-none-any.whl", hash = "sha256:22a1748707dd62b58d2ae53562ffc4d7f8bcc727e8ac7cbc69c053ddc874d47e", size = 2352367, upload-time = "2025-09-19T00:10:15.489Z" },
]
[[package]]
name = "mypy-extensions"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
{ url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" },
]
[[package]]
@@ -1024,15 +992,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/a4/ab6b7589382ca3df236e03faa71deac88cae040af60c071a78d254a62172/passlib-1.7.4-py2.py3-none-any.whl", hash = "sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1", size = 525554, upload-time = "2020-10-08T19:00:49.856Z" },
]
[[package]]
name = "pathspec"
version = "0.12.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" },
]
[[package]]
name = "pillow"
version = "12.0.0"
@@ -1302,6 +1261,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/04/af/d8bf0959ece9bc4679bd203908c31019556a421d76d8143b0c6871c7f614/pyrate_limiter-3.9.0-py3-none-any.whl", hash = "sha256:77357840c8cf97a36d67005d4e090787043f54000c12c2b414ff65657653e378", size = 33628, upload-time = "2025-07-30T14:36:57.71Z" },
]
[[package]]
name = "pyright"
version = "1.1.408"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nodeenv" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/74/b2/5db700e52554b8f025faa9c3c624c59f1f6c8841ba81ab97641b54322f16/pyright-1.1.408.tar.gz", hash = "sha256:f28f2321f96852fa50b5829ea492f6adb0e6954568d1caa3f3af3a5f555eb684", size = 4400578, upload-time = "2026-01-08T08:07:38.795Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0c/82/a2c93e32800940d9573fb28c346772a14778b84ba7524e691b324620ab89/pyright-1.1.408-py3-none-any.whl", hash = "sha256:090b32865f4fdb1e0e6cd82bf5618480d48eecd2eb2e70f960982a3d9a4c17c1", size = 6399144, upload-time = "2026-01-08T08:07:37.082Z" },
]
[[package]]
name = "pytest"
version = "8.4.2"